From f93ec50424d0f469e273c44c79e5b520b2c35d5b Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Wed, 7 Aug 2019 16:39:32 +0200 Subject: [PATCH] Compare Queryable operators by MethodInfo Part of #16300, not exhaustive --- ...lationalSqlTranslatingExpressionVisitor.cs | 12 +- ...ntityEqualityRewritingExpressionVisitor.cs | 69 +-- .../NavigationExpandingExpressionVisitor.cs | 163 +++--- src/EFCore/Query/QueryableMethodProvider.cs | 17 +- ...yableMethodTranslatingExpressionVisitor.cs | 489 ++++++++---------- 5 files changed, 370 insertions(+), 380 deletions(-) diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index f7dad2d5461..082f2c669ee 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -5,6 +5,7 @@ using System.Diagnostics; using System.Linq; using System.Linq.Expressions; +using System.Reflection; using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Metadata; @@ -316,10 +317,13 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp return null; } - if ((methodCallExpression.Method.Name == nameof(Queryable.Any) - || methodCallExpression.Method.Name == nameof(Queryable.All) - || methodCallExpression.Method.Name == nameof(Queryable.Contains)) - && subquery.Tables.Count == 0) + if (subquery.Tables.Count == 0 + && methodCallExpression.Method.IsGenericMethod + && methodCallExpression.Method.GetGenericMethodDefinition() is MethodInfo genericMethod + && (genericMethod == QueryableMethodProvider.AnyWithoutPredicateMethodInfo + || genericMethod == QueryableMethodProvider.AnyWithPredicateMethodInfo + || genericMethod == QueryableMethodProvider.AllMethodInfo + || genericMethod == QueryableMethodProvider.ContainsMethodInfo)) { return subquery.Projection[0].Expression; } diff --git a/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs b/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs index fc31230e3f1..e813cfcd09c 100644 --- a/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs @@ -188,6 +188,7 @@ protected override Expression VisitConditional(ConditionalExpression conditional protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) { var method = methodCallExpression.Method; + var genericMethod = method.IsGenericMethod ? method.GetGenericMethodDefinition() : null; var arguments = methodCallExpression.Arguments; Expression newSource; @@ -219,37 +220,48 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp : newMethodCall; } - if (method.DeclaringType == typeof(Queryable) || method.DeclaringType == typeof(QueryableExtensions)) + switch (method.Name) { - switch (method.Name) - { - // These are methods that require special handling - case nameof(Queryable.Contains) when arguments.Count == 2: - return VisitContainsMethodCall(methodCallExpression); - - case nameof(Queryable.OrderBy) when arguments.Count == 2: - case nameof(Queryable.OrderByDescending) when arguments.Count == 2: - case nameof(Queryable.ThenBy) when arguments.Count == 2: - case nameof(Queryable.ThenByDescending) when arguments.Count == 2: - return VisitOrderingMethodCall(methodCallExpression); - - // The following are projecting methods, which flow the entity type from *within* the lambda outside. - case nameof(Queryable.Select): - case nameof(Queryable.SelectMany): - return VisitSelectMethodCall(methodCallExpression); - - case nameof(Queryable.GroupJoin): - case nameof(Queryable.Join): - case nameof(QueryableExtensions.LeftJoin): - return VisitJoinMethodCall(methodCallExpression); - - case nameof(Queryable.GroupBy): // TODO: Implement - break; - } + // These are methods that require special handling + case nameof(Queryable.Contains) + when genericMethod == QueryableMethodProvider.ContainsMethodInfo: + return VisitContainsMethodCall(methodCallExpression); + + case nameof(Queryable.OrderBy) + when genericMethod == QueryableMethodProvider.OrderByMethodInfo: + case nameof(Queryable.OrderByDescending) + when genericMethod == QueryableMethodProvider.OrderByDescendingMethodInfo: + case nameof(Queryable.ThenBy) + when genericMethod == QueryableMethodProvider.ThenByMethodInfo: + case nameof(Queryable.ThenByDescending) when genericMethod == QueryableMethodProvider.ThenByDescendingMethodInfo: + return VisitOrderingMethodCall(methodCallExpression); + + // The following are projecting methods, which flow the entity type from *within* the lambda outside. + case nameof(Queryable.Select) + when genericMethod == QueryableMethodProvider.SelectMethodInfo: + case nameof(Queryable.SelectMany) + when genericMethod == QueryableMethodProvider.SelectManyWithoutCollectionSelectorMethodInfo || + genericMethod == QueryableMethodProvider.SelectManyWithCollectionSelectorMethodInfo: + return VisitSelectMethodCall(methodCallExpression); + + case nameof(Queryable.GroupJoin) + when genericMethod == QueryableMethodProvider.GroupJoinMethodInfo: + case nameof(Queryable.Join) + when genericMethod == QueryableMethodProvider.JoinMethodInfo: + case nameof(QueryableExtensions.LeftJoin) + when genericMethod == QueryableExtensions.LeftJoinMethodInfo: + return VisitJoinMethodCall(methodCallExpression); + + case nameof(Queryable.GroupBy) + when genericMethod == QueryableMethodProvider.GroupByWithKeySelectorMethodInfo || + genericMethod == QueryableMethodProvider.GroupByWithKeyElementSelectorMethodInfo || + genericMethod == QueryableMethodProvider.GroupByWithKeyResultSelectorMethodInfo || + genericMethod == QueryableMethodProvider.GroupByWithKeyElementResultSelectorMethodInfo: + break; // TODO: Implement } // We handled the Contains Queryable extension method above, but there's also IList.Contains - if (method.IsGenericMethod && method.GetGenericMethodDefinition().Equals(_enumerableContainsMethodInfo) + if (genericMethod == _enumerableContainsMethodInfo || method.DeclaringType.GetInterfaces().Contains(typeof(IList)) && string.Equals(method.Name, nameof(IList.Contains))) { return VisitContainsMethodCall(methodCallExpression); @@ -293,8 +305,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp if (methodCallExpression.Method.ReturnType.TryGetSequenceType() is Type returnElementType && (returnElementType == sourceElementType || sourceElementType == null)) { - return newSourceWrapper.Update( - methodCallExpression.Update(null, newArguments)); + return newSourceWrapper.Update(methodCallExpression.Update(null, newArguments)); } // If the source type is an IQueryable over the return type, this is a cardinality-reducing method (e.g. First). diff --git a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs index 893520242b8..5591ee354b4 100644 --- a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs @@ -151,9 +151,8 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp var firstArgument = Visit(methodCallExpression.Arguments[0]); if (firstArgument is NavigationExpansionExpression source) { - var genericMethod = methodCallExpression.Method.IsGenericMethod - ? methodCallExpression.Method.GetGenericMethodDefinition() - : null; + var method = methodCallExpression.Method; + var genericMethod = method.IsGenericMethod ? method.GetGenericMethodDefinition() : null; if (source.PendingOrderings.Any() && genericMethod != QueryableMethodProvider.ThenByMethodInfo @@ -183,13 +182,10 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp case nameof(Queryable.All) when genericMethod == QueryableMethodProvider.AllMethodInfo: - case nameof(Queryable.Any) when genericMethod == QueryableMethodProvider.AnyWithPredicateMethodInfo: - case nameof(Queryable.Count) when genericMethod == QueryableMethodProvider.CountWithPredicateMethodInfo: - case nameof(Queryable.LongCount) when genericMethod == QueryableMethodProvider.LongCountWithPredicateMethodInfo: return ProcessAllAnyCountLongCount( @@ -197,10 +193,16 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp genericMethod, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - case nameof(Queryable.Average): - case nameof(Queryable.Max): - case nameof(Queryable.Min): - case nameof(Queryable.Sum): + case nameof(Queryable.Average) + when QueryableMethodProvider.IsAverageMethodInfo(method): + case nameof(Queryable.Sum) + when QueryableMethodProvider.IsSumMethodInfo(method): + case nameof(Queryable.Max) + when genericMethod == QueryableMethodProvider.MaxWithoutSelectorMethodInfo || + genericMethod == QueryableMethodProvider.MaxWithSelectorMethodInfo: + case nameof(Queryable.Min) + when genericMethod == QueryableMethodProvider.MinWithoutSelectorMethodInfo || + genericMethod == QueryableMethodProvider.MinWithSelectorMethodInfo: return ProcessAverageMaxMinSum( source, methodCallExpression.Method.IsGenericMethod @@ -210,9 +212,12 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() : null); - case nameof(Queryable.Distinct): - case nameof(Queryable.Skip): - case nameof(Queryable.Take): + case nameof(Queryable.Distinct) + when genericMethod == QueryableMethodProvider.DistinctMethodInfo: + case nameof(Queryable.Skip) + when genericMethod == QueryableMethodProvider.SkipMethodInfo: + case nameof(Queryable.Take) + when genericMethod == QueryableMethodProvider.TakeMethodInfo: return ProcessDistinctSkipTake( source, methodCallExpression.Method.GetGenericMethodDefinition(), @@ -220,17 +225,30 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp ? methodCallExpression.Arguments[1] : null); - case nameof(Queryable.Contains): + case nameof(Queryable.Contains) + when genericMethod == QueryableMethodProvider.ContainsMethodInfo: return ProcessContains( source, methodCallExpression.Arguments[1]); - case nameof(Queryable.First): - case nameof(Queryable.FirstOrDefault): - case nameof(Queryable.Single): - case nameof(Queryable.SingleOrDefault): - case nameof(Queryable.Last): - case nameof(Queryable.LastOrDefault): + case nameof(Queryable.First) + when genericMethod == QueryableMethodProvider.FirstWithoutPredicateMethodInfo || + genericMethod == QueryableMethodProvider.FirstWithPredicateMethodInfo: + case nameof(Queryable.FirstOrDefault) + when genericMethod == QueryableMethodProvider.FirstOrDefaultWithoutPredicateMethodInfo || + genericMethod == QueryableMethodProvider.FirstOrDefaultWithPredicateMethodInfo: + case nameof(Queryable.Single) + when genericMethod == QueryableMethodProvider.SingleWithoutPredicateMethodInfo || + genericMethod == QueryableMethodProvider.SingleWithPredicateMethodInfo: + case nameof(Queryable.SingleOrDefault) + when genericMethod == QueryableMethodProvider.SingleOrDefaultWithoutPredicateMethodInfo || + genericMethod == QueryableMethodProvider.SingleOrDefaultWithPredicateMethodInfo: + case nameof(Queryable.Last) + when genericMethod == QueryableMethodProvider.LastWithoutPredicateMethodInfo || + genericMethod == QueryableMethodProvider.LastWithPredicateMethodInfo: + case nameof(Queryable.LastOrDefault) + when genericMethod == QueryableMethodProvider.LastOrDefaultWithoutPredicateMethodInfo || + genericMethod == QueryableMethodProvider.LastOrDefaultWithPredicateMethodInfo: return ProcessFirstSingleLastOrDefault( source, methodCallExpression.Method.GetGenericMethodDefinition(), @@ -239,37 +257,41 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp : null, methodCallExpression.Type); - case nameof(Queryable.Join): + case nameof(Queryable.Join) + when genericMethod == QueryableMethodProvider.JoinMethodInfo: + { + var secondArgument = Visit(methodCallExpression.Arguments[1]); + if (secondArgument is NavigationExpansionExpression innerSource) { - var secondArgument = Visit(methodCallExpression.Arguments[1]); - if (secondArgument is NavigationExpansionExpression innerSource) - { - return ProcessJoin( - source, - innerSource, - methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments[4].UnwrapLambdaFromQuote()); - } + return ProcessJoin( + source, + innerSource, + methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(), + methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(), + methodCallExpression.Arguments[4].UnwrapLambdaFromQuote()); } break; + } - case nameof(QueryableExtensions.LeftJoin): + case nameof(QueryableExtensions.LeftJoin) + when genericMethod == QueryableExtensions.LeftJoinMethodInfo: + { + var secondArgument = Visit(methodCallExpression.Arguments[1]); + if (secondArgument is NavigationExpansionExpression innerSource) { - var secondArgument = Visit(methodCallExpression.Arguments[1]); - if (secondArgument is NavigationExpansionExpression innerSource) - { - return ProcessLeftJoin( - source, - innerSource, - methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments[4].UnwrapLambdaFromQuote()); - } + return ProcessLeftJoin( + source, + innerSource, + methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(), + methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(), + methodCallExpression.Arguments[4].UnwrapLambdaFromQuote()); } break; + } - case nameof(Queryable.SelectMany): + case nameof(Queryable.SelectMany) + when genericMethod == QueryableMethodProvider.SelectManyWithoutCollectionSelectorMethodInfo || + genericMethod == QueryableMethodProvider.SelectManyWithCollectionSelectorMethodInfo: return ProcessSelectMany( source, methodCallExpression.Method.GetGenericMethodDefinition(), @@ -278,31 +300,35 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp ? methodCallExpression.Arguments[2].UnwrapLambdaFromQuote() : null); - case nameof(Queryable.Concat): - case nameof(Queryable.Except): - case nameof(Queryable.Intersect): - case nameof(Queryable.Union): + case nameof(Queryable.Concat) + when genericMethod == QueryableMethodProvider.ConcatMethodInfo: + case nameof(Queryable.Except) + when genericMethod == QueryableMethodProvider.ExceptMethodInfo: + case nameof(Queryable.Intersect) + when genericMethod == QueryableMethodProvider.IntersectMethodInfo: + case nameof(Queryable.Union) + when genericMethod == QueryableMethodProvider.UnionMethodInfo: + { + var secondArgument = Visit(methodCallExpression.Arguments[1]); + if (secondArgument is NavigationExpansionExpression innerSource) { - var secondArgument = Visit(methodCallExpression.Arguments[1]); - if (secondArgument is NavigationExpansionExpression innerSource) - { - return ProcessSetOperation( - source, - methodCallExpression.Method.GetGenericMethodDefinition(), - innerSource); - } + return ProcessSetOperation( + source, + methodCallExpression.Method.GetGenericMethodDefinition(), + innerSource); } break; + } - case nameof(Queryable.Cast): - case nameof(Queryable.OfType): + case nameof(Queryable.Cast) + when genericMethod == QueryableMethodProvider.CastMethodInfo: + case nameof(Queryable.OfType) + when genericMethod == QueryableMethodProvider.OfTypeMethodInfo: return ProcessCastOfType( source, methodCallExpression.Method.GetGenericMethodDefinition(), methodCallExpression.Type.TryGetSequenceType()); - - case nameof(EntityFrameworkQueryableExtensions.Include): case nameof(EntityFrameworkQueryableExtensions.ThenInclude): return ProcessInclude( @@ -343,29 +369,34 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp null, methodCallExpression.Arguments[2].UnwrapLambdaFromQuote()); - case nameof(Queryable.OrderBy): - case nameof(Queryable.OrderByDescending): + case nameof(Queryable.OrderBy) + when genericMethod == QueryableMethodProvider.OrderByMethodInfo: + case nameof(Queryable.OrderByDescending) + when genericMethod == QueryableMethodProvider.OrderByDescendingMethodInfo: return ProcessOrderByThenBy( source, methodCallExpression.Method.GetGenericMethodDefinition(), methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: false); - case nameof(Queryable.ThenBy): - case nameof(Queryable.ThenByDescending): + case nameof(Queryable.ThenBy) + when genericMethod == QueryableMethodProvider.ThenByMethodInfo: + case nameof(Queryable.ThenByDescending) + when genericMethod == QueryableMethodProvider.ThenByDescendingMethodInfo: return ProcessOrderByThenBy( source, methodCallExpression.Method.GetGenericMethodDefinition(), methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: true); - - case nameof(Queryable.Select): + case nameof(Queryable.Select) + when genericMethod == QueryableMethodProvider.SelectMethodInfo: return ProcessSelect( source, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - case nameof(Queryable.Where): + case nameof(Queryable.Where) + when genericMethod == QueryableMethodProvider.WhereMethodInfo: return ProcessWhere( source, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); diff --git a/src/EFCore/Query/QueryableMethodProvider.cs b/src/EFCore/Query/QueryableMethodProvider.cs index 6f97ef2f38f..3fc326ef115 100644 --- a/src/EFCore/Query/QueryableMethodProvider.cs +++ b/src/EFCore/Query/QueryableMethodProvider.cs @@ -40,7 +40,6 @@ private static bool IsExpressionOfFunc(Type type, int funcGenericArgs = 2) public static MethodInfo MaxWithSelectorMethodInfo { get; } public static MethodInfo MaxWithoutSelectorMethodInfo { get; } - public static MethodInfo ElementAtMethodInfo { get; } public static MethodInfo ElementAtOrDefaultMethodInfo { get; } public static MethodInfo FirstWithoutPredicateMethodInfo { get; } @@ -86,6 +85,22 @@ private static bool IsExpressionOfFunc(Type type, int funcGenericArgs = 2) public static IDictionary AverageWithoutSelectorMethodInfos { get; } public static IDictionary AverageWithSelectorMethodInfos { get; } + public static bool IsSumWithoutSelectorMethodInfo(MethodInfo methodInfo) + => SumWithoutSelectorMethodInfos.Values.Contains(methodInfo); + public static bool IsSumWithSelectorMethodInfo(MethodInfo methodInfo) + => methodInfo.IsGenericMethod + && SumWithSelectorMethodInfos.Values.Contains(methodInfo.GetGenericMethodDefinition()); + public static bool IsSumMethodInfo(MethodInfo methodInfo) + => IsSumWithoutSelectorMethodInfo(methodInfo) || IsSumWithSelectorMethodInfo(methodInfo); + + public static bool IsAverageWithoutSelectorMethodInfo(MethodInfo methodInfo) + => AverageWithoutSelectorMethodInfos.Values.Contains(methodInfo); + public static bool IsAverageWithSelectorMethodInfo(MethodInfo methodInfo) + => methodInfo.IsGenericMethod + && AverageWithSelectorMethodInfos.Values.Contains(methodInfo.GetGenericMethodDefinition()); + public static bool IsAverageMethodInfo(MethodInfo methodInfo) + => IsAverageWithoutSelectorMethodInfo(methodInfo) || IsAverageWithSelectorMethodInfo(methodInfo); + static QueryableMethodProvider() { var queryableMethods = typeof(Queryable).GetTypeInfo() diff --git a/src/EFCore/Query/QueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore/Query/QueryableMethodTranslatingExpressionVisitor.cs index 1203394b52c..50796808d15 100644 --- a/src/EFCore/Query/QueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore/Query/QueryableMethodTranslatingExpressionVisitor.cs @@ -31,49 +31,50 @@ protected override Expression VisitConstant(ConstantExpression constantExpressio protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) { - if (methodCallExpression.Method.DeclaringType == typeof(Queryable) - || methodCallExpression.Method.DeclaringType == typeof(QueryableExtensions)) + var method = methodCallExpression.Method; + if (method.DeclaringType == typeof(Queryable) || method.DeclaringType == typeof(QueryableExtensions)) { var source = Visit(methodCallExpression.Arguments[0]); if (source is ShapedQueryExpression shapedQueryExpression) { - var argumentCount = methodCallExpression.Arguments.Count; - switch (methodCallExpression.Method.Name) + var genericMethod = method.IsGenericMethod ? method.GetGenericMethodDefinition() : null; + switch (method.Name) { - case nameof(Queryable.Aggregate): - // Don't know - break; + case nameof(Queryable.All) + when genericMethod == QueryableMethodProvider.AllMethodInfo: + shapedQueryExpression.ResultCardinality = ResultCardinality.Single; + return TranslateAll(shapedQueryExpression, GetLambdaExpressionFromArgument(1)); - case nameof(Queryable.All): + case nameof(Queryable.Any) + when genericMethod == QueryableMethodProvider.AnyWithoutPredicateMethodInfo: shapedQueryExpression.ResultCardinality = ResultCardinality.Single; - return TranslateAll( - shapedQueryExpression, - methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + return TranslateAny(shapedQueryExpression, null); - case nameof(Queryable.Any): + case nameof(Queryable.Any) + when genericMethod == QueryableMethodProvider.AnyWithPredicateMethodInfo: shapedQueryExpression.ResultCardinality = ResultCardinality.Single; - return TranslateAny( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null); + return TranslateAny(shapedQueryExpression, GetLambdaExpressionFromArgument(1)); - case nameof(Queryable.AsQueryable): + case nameof(Queryable.AsQueryable) + when genericMethod == QueryableMethodProvider.AsQueryableMethodInfo: return source; - case nameof(Queryable.Average): + case nameof(Queryable.Average) + when QueryableMethodProvider.IsAverageWithoutSelectorMethodInfo(method): shapedQueryExpression.ResultCardinality = ResultCardinality.Single; - return TranslateAverage( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null, - methodCallExpression.Type); + return TranslateAverage(shapedQueryExpression, null, methodCallExpression.Type); - case nameof(Queryable.Cast): - return TranslateCast(shapedQueryExpression, methodCallExpression.Method.GetGenericArguments()[0]); + case nameof(Queryable.Average) + when QueryableMethodProvider.IsAverageWithSelectorMethodInfo(method): + shapedQueryExpression.ResultCardinality = ResultCardinality.Single; + return TranslateAverage(shapedQueryExpression, GetLambdaExpressionFromArgument(1), methodCallExpression.Type); - case nameof(Queryable.Concat): + case nameof(Queryable.Cast) + when genericMethod == QueryableMethodProvider.CastMethodInfo: + return TranslateCast(shapedQueryExpression, method.GetGenericArguments()[0]); + + case nameof(Queryable.Concat) + when genericMethod == QueryableMethodProvider.ConcatMethodInfo: { var source2 = Visit(methodCallExpression.Arguments[1]); if (source2 is ShapedQueryExpression innerShapedQueryExpression) @@ -82,44 +83,48 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp shapedQueryExpression, innerShapedQueryExpression); } - } - break; + } case nameof(Queryable.Contains) - when argumentCount == 2: + when genericMethod == QueryableMethodProvider.ContainsMethodInfo: shapedQueryExpression.ResultCardinality = ResultCardinality.Single; return TranslateContains(shapedQueryExpression, methodCallExpression.Arguments[1]); - case nameof(Queryable.Count): + case nameof(Queryable.Count) + when genericMethod == QueryableMethodProvider.CountWithoutPredicateMethodInfo: shapedQueryExpression.ResultCardinality = ResultCardinality.Single; - return TranslateCount( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null); - - case nameof(Queryable.DefaultIfEmpty): - return TranslateDefaultIfEmpty( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1] - : null); + return TranslateCount(shapedQueryExpression, null); + + case nameof(Queryable.Count) + when genericMethod == QueryableMethodProvider.CountWithPredicateMethodInfo: + shapedQueryExpression.ResultCardinality = ResultCardinality.Single; + return TranslateCount(shapedQueryExpression, GetLambdaExpressionFromArgument(1)); + + case nameof(Queryable.DefaultIfEmpty) + when genericMethod == QueryableMethodProvider.DefaultIfEmptyWithoutArgumentMethodInfo: + return TranslateDefaultIfEmpty(shapedQueryExpression, null); + + case nameof(Queryable.DefaultIfEmpty) + when genericMethod == QueryableMethodProvider.DefaultIfEmptyWithArgumentMethodInfo: + return TranslateDefaultIfEmpty(shapedQueryExpression, methodCallExpression.Arguments[1]); case nameof(Queryable.Distinct) - when argumentCount == 1: + when genericMethod == QueryableMethodProvider.DistinctMethodInfo: return TranslateDistinct(shapedQueryExpression); - case nameof(Queryable.ElementAt): + case nameof(Queryable.ElementAt) + when genericMethod == QueryableMethodProvider.ElementAtMethodInfo: shapedQueryExpression.ResultCardinality = ResultCardinality.Single; return TranslateElementAtOrDefault(shapedQueryExpression, methodCallExpression.Arguments[1], false); - case nameof(Queryable.ElementAtOrDefault): + case nameof(Queryable.ElementAtOrDefault) + when genericMethod == QueryableMethodProvider.ElementAtOrDefaultMethodInfo: shapedQueryExpression.ResultCardinality = ResultCardinality.SingleOrDefault; return TranslateElementAtOrDefault(shapedQueryExpression, methodCallExpression.Arguments[1], true); case nameof(Queryable.Except) - when argumentCount == 2: + when genericMethod == QueryableMethodProvider.ExceptMethodInfo: { var source2 = Visit(methodCallExpression.Arguments[1]); if (source2 is ShapedQueryExpression innerShapedQueryExpression) @@ -128,317 +133,241 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp shapedQueryExpression, innerShapedQueryExpression); } + break; } - break; + case nameof(Queryable.First) + when genericMethod == QueryableMethodProvider.FirstWithoutPredicateMethodInfo: + shapedQueryExpression.ResultCardinality = ResultCardinality.Single; + return TranslateFirstOrDefault(shapedQueryExpression, null, methodCallExpression.Type, false); - case nameof(Queryable.First): + case nameof(Queryable.First) + when genericMethod == QueryableMethodProvider.FirstWithPredicateMethodInfo: shapedQueryExpression.ResultCardinality = ResultCardinality.Single; - return TranslateFirstOrDefault( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null, - methodCallExpression.Type, - false); - - case nameof(Queryable.FirstOrDefault): + return TranslateFirstOrDefault(shapedQueryExpression, GetLambdaExpressionFromArgument(1), methodCallExpression.Type, false); + + case nameof(Queryable.FirstOrDefault) + when genericMethod == QueryableMethodProvider.FirstOrDefaultWithoutPredicateMethodInfo: shapedQueryExpression.ResultCardinality = ResultCardinality.SingleOrDefault; - return TranslateFirstOrDefault( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null, - methodCallExpression.Type, - true); - - case nameof(Queryable.GroupBy): - { - var keySelector = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - if (methodCallExpression.Arguments[argumentCount - 1] is ConstantExpression) - { - // This means last argument is EqualityComparer on key - // which is not supported - break; - } + return TranslateFirstOrDefault(shapedQueryExpression, null, methodCallExpression.Type, true); - switch (argumentCount) - { - case 2: - return TranslateGroupBy( - shapedQueryExpression, - keySelector, - null, - null); - - case 3: - var lambda = methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(); - if (lambda.Parameters.Count == 1) - { - return TranslateGroupBy( - shapedQueryExpression, - keySelector, - lambda, - null); - } - else - { - return TranslateGroupBy( - shapedQueryExpression, - keySelector, - null, - lambda); - } - - case 4: - return TranslateGroupBy( - shapedQueryExpression, - keySelector, - methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments[3].UnwrapLambdaFromQuote()); - } - } + case nameof(Queryable.FirstOrDefault) + when genericMethod == QueryableMethodProvider.FirstOrDefaultWithPredicateMethodInfo: + shapedQueryExpression.ResultCardinality = ResultCardinality.SingleOrDefault; + return TranslateFirstOrDefault(shapedQueryExpression, GetLambdaExpressionFromArgument(1), methodCallExpression.Type, true); - break; + case nameof(Queryable.GroupBy) + when genericMethod == QueryableMethodProvider.GroupByWithKeySelectorMethodInfo: + return TranslateGroupBy(shapedQueryExpression, GetLambdaExpressionFromArgument(1), null, null); + + case nameof(Queryable.GroupBy) + when genericMethod == QueryableMethodProvider.GroupByWithKeyElementSelectorMethodInfo: + return TranslateGroupBy(shapedQueryExpression, GetLambdaExpressionFromArgument(1), GetLambdaExpressionFromArgument(2), null); + + case nameof(Queryable.GroupBy) + when genericMethod == QueryableMethodProvider.GroupByWithKeyElementResultSelectorMethodInfo: + return TranslateGroupBy(shapedQueryExpression, GetLambdaExpressionFromArgument(1), GetLambdaExpressionFromArgument(2), GetLambdaExpressionFromArgument(3)); + + case nameof(Queryable.GroupBy) + when genericMethod == QueryableMethodProvider.GroupByWithKeyResultSelectorMethodInfo: + return TranslateGroupBy(shapedQueryExpression, GetLambdaExpressionFromArgument(1), null, GetLambdaExpressionFromArgument(2)); case nameof(Queryable.GroupJoin) - when argumentCount == 5: + when genericMethod == QueryableMethodProvider.GroupJoinMethodInfo: { - var innerSource = Visit(methodCallExpression.Arguments[1]); - if (innerSource is ShapedQueryExpression innerShapedQueryExpression) + if (Visit(methodCallExpression.Arguments[1]) is ShapedQueryExpression innerShapedQueryExpression) { - return TranslateGroupJoin( - shapedQueryExpression, - innerShapedQueryExpression, - methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments[4].UnwrapLambdaFromQuote()); + return TranslateGroupJoin(shapedQueryExpression, innerShapedQueryExpression, GetLambdaExpressionFromArgument(2), GetLambdaExpressionFromArgument(3), GetLambdaExpressionFromArgument(4)); } - } - break; + } case nameof(Queryable.Intersect) - when argumentCount == 2: + when genericMethod == QueryableMethodProvider.IntersectMethodInfo: { - var source2 = Visit(methodCallExpression.Arguments[1]); - if (source2 is ShapedQueryExpression innerShapedQueryExpression) + if (Visit(methodCallExpression.Arguments[1]) is ShapedQueryExpression innerShapedQueryExpression) { return TranslateIntersect( shapedQueryExpression, innerShapedQueryExpression); } - } - break; + } case nameof(Queryable.Join) - when argumentCount == 5: + when genericMethod == QueryableMethodProvider.JoinMethodInfo: { - var innerSource = Visit(methodCallExpression.Arguments[1]); - if (innerSource is ShapedQueryExpression innerShapedQueryExpression) + if (Visit(methodCallExpression.Arguments[1]) is ShapedQueryExpression innerShapedQueryExpression) { - return TranslateJoin( - shapedQueryExpression, - innerShapedQueryExpression, - methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments[4].UnwrapLambdaFromQuote()); + return TranslateJoin(shapedQueryExpression, innerShapedQueryExpression, GetLambdaExpressionFromArgument(2), GetLambdaExpressionFromArgument(3), GetLambdaExpressionFromArgument(4)); } - } - break; + } case nameof(QueryableExtensions.LeftJoin) - when argumentCount == 5: + when genericMethod == QueryableExtensions.LeftJoinMethodInfo: { - var innerSource = Visit(methodCallExpression.Arguments[1]); - if (innerSource is ShapedQueryExpression innerShapedQueryExpression) + if (Visit(methodCallExpression.Arguments[1]) is ShapedQueryExpression innerShapedQueryExpression) { - return TranslateLeftJoin( - shapedQueryExpression, - innerShapedQueryExpression, - methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments[4].UnwrapLambdaFromQuote()); + return TranslateLeftJoin(shapedQueryExpression, innerShapedQueryExpression, GetLambdaExpressionFromArgument(2), GetLambdaExpressionFromArgument(3), GetLambdaExpressionFromArgument(4)); } + break; } - break; + case nameof(Queryable.Last) + when genericMethod == QueryableMethodProvider.LastWithoutPredicateMethodInfo: + shapedQueryExpression.ResultCardinality = ResultCardinality.Single; + return TranslateLastOrDefault(shapedQueryExpression, null, methodCallExpression.Type, false); - case nameof(Queryable.Last): + case nameof(Queryable.Last) + when genericMethod == QueryableMethodProvider.LastWithPredicateMethodInfo: shapedQueryExpression.ResultCardinality = ResultCardinality.Single; - return TranslateLastOrDefault( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null, - methodCallExpression.Type, - false); - - case nameof(Queryable.LastOrDefault): + return TranslateLastOrDefault(shapedQueryExpression, GetLambdaExpressionFromArgument(1), methodCallExpression.Type, false); + + case nameof(Queryable.LastOrDefault) + when genericMethod == QueryableMethodProvider.LastOrDefaultWithoutPredicateMethodInfo: shapedQueryExpression.ResultCardinality = ResultCardinality.SingleOrDefault; - return TranslateLastOrDefault( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null, - methodCallExpression.Type, - true); - - case nameof(Queryable.LongCount): + return TranslateLastOrDefault(shapedQueryExpression, null, methodCallExpression.Type, true); + + case nameof(Queryable.LastOrDefault) + when genericMethod == QueryableMethodProvider.LastOrDefaultWithPredicateMethodInfo: + shapedQueryExpression.ResultCardinality = ResultCardinality.SingleOrDefault; + return TranslateLastOrDefault(shapedQueryExpression, GetLambdaExpressionFromArgument(1), methodCallExpression.Type, true); + + case nameof(Queryable.LongCount) + when genericMethod == QueryableMethodProvider.LongCountWithoutPredicateMethodInfo: shapedQueryExpression.ResultCardinality = ResultCardinality.Single; - return TranslateLongCount( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null); + return TranslateLongCount(shapedQueryExpression, null); - case nameof(Queryable.Max): + case nameof(Queryable.LongCount) + when genericMethod == QueryableMethodProvider.LongCountWithPredicateMethodInfo: shapedQueryExpression.ResultCardinality = ResultCardinality.Single; - return TranslateMax( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null, - methodCallExpression.Type); - - case nameof(Queryable.Min): + return TranslateLongCount(shapedQueryExpression, GetLambdaExpressionFromArgument(1)); + + case nameof(Queryable.Max) + when genericMethod == QueryableMethodProvider.MaxWithoutSelectorMethodInfo: shapedQueryExpression.ResultCardinality = ResultCardinality.Single; - return TranslateMin( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null, - methodCallExpression.Type); + return TranslateMax(shapedQueryExpression, null, methodCallExpression.Type); - case nameof(Queryable.OfType): - return TranslateOfType(shapedQueryExpression, methodCallExpression.Method.GetGenericArguments()[0]); + case nameof(Queryable.Max) + when genericMethod == QueryableMethodProvider.MaxWithSelectorMethodInfo: + shapedQueryExpression.ResultCardinality = ResultCardinality.Single; + return TranslateMax(shapedQueryExpression, GetLambdaExpressionFromArgument(1), methodCallExpression.Type); + + case nameof(Queryable.Min) + when genericMethod == QueryableMethodProvider.MinWithoutSelectorMethodInfo: + shapedQueryExpression.ResultCardinality = ResultCardinality.Single; + return TranslateMin(shapedQueryExpression, null, methodCallExpression.Type); + + case nameof(Queryable.Min) + when genericMethod == QueryableMethodProvider.MinWithSelectorMethodInfo: + shapedQueryExpression.ResultCardinality = ResultCardinality.Single; + return TranslateMin(shapedQueryExpression, GetLambdaExpressionFromArgument(1), methodCallExpression.Type); + + case nameof(Queryable.OfType) + when genericMethod == QueryableMethodProvider.OfTypeMethodInfo: + return TranslateOfType(shapedQueryExpression, method.GetGenericArguments()[0]); case nameof(Queryable.OrderBy) - when argumentCount == 2: - return TranslateOrderBy( - shapedQueryExpression, - methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), - true); + when genericMethod == QueryableMethodProvider.OrderByMethodInfo: + return TranslateOrderBy(shapedQueryExpression, GetLambdaExpressionFromArgument(1), true); case nameof(Queryable.OrderByDescending) - when argumentCount == 2: - return TranslateOrderBy( - shapedQueryExpression, - methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), - false); + when genericMethod == QueryableMethodProvider.OrderByDescendingMethodInfo: + return TranslateOrderBy(shapedQueryExpression, GetLambdaExpressionFromArgument(1), false); - case nameof(Queryable.Reverse): + case nameof(Queryable.Reverse) + when genericMethod == QueryableMethodProvider.ReverseMethodInfo: return TranslateReverse(shapedQueryExpression); - case nameof(Queryable.Select): - return TranslateSelect( - shapedQueryExpression, - methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + case nameof(Queryable.Select) + when genericMethod == QueryableMethodProvider.SelectMethodInfo: + return TranslateSelect(shapedQueryExpression, GetLambdaExpressionFromArgument(1)); - case nameof(Queryable.SelectMany): - return methodCallExpression.Arguments.Count == 2 - ? TranslateSelectMany( - shapedQueryExpression, - methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()) - : TranslateSelectMany( - shapedQueryExpression, - methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments[2].UnwrapLambdaFromQuote()); + case nameof(Queryable.SelectMany) + when genericMethod == QueryableMethodProvider.SelectManyWithoutCollectionSelectorMethodInfo: + return TranslateSelectMany(shapedQueryExpression, GetLambdaExpressionFromArgument(1)); - case nameof(Queryable.SequenceEqual): - // don't know - break; + case nameof(Queryable.SelectMany) + when genericMethod == QueryableMethodProvider.SelectManyWithCollectionSelectorMethodInfo: + return TranslateSelectMany(shapedQueryExpression, GetLambdaExpressionFromArgument(1), GetLambdaExpressionFromArgument(2)); + + case nameof(Queryable.Single) + when genericMethod == QueryableMethodProvider.SingleWithoutPredicateMethodInfo: + shapedQueryExpression.ResultCardinality = ResultCardinality.Single; + return TranslateSingleOrDefault(shapedQueryExpression, null, methodCallExpression.Type, false); - case nameof(Queryable.Single): + case nameof(Queryable.Single) + when genericMethod == QueryableMethodProvider.SingleWithPredicateMethodInfo: shapedQueryExpression.ResultCardinality = ResultCardinality.Single; - return TranslateSingleOrDefault( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null, - methodCallExpression.Type, - false); - - case nameof(Queryable.SingleOrDefault): + return TranslateSingleOrDefault(shapedQueryExpression, GetLambdaExpressionFromArgument(1), methodCallExpression.Type, false); + + case nameof(Queryable.SingleOrDefault) + when genericMethod == QueryableMethodProvider.SingleOrDefaultWithoutPredicateMethodInfo: shapedQueryExpression.ResultCardinality = ResultCardinality.SingleOrDefault; - return TranslateSingleOrDefault( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null, - methodCallExpression.Type, - true); - - case nameof(Queryable.Skip): + return TranslateSingleOrDefault(shapedQueryExpression, null, methodCallExpression.Type, true); + + case nameof(Queryable.SingleOrDefault) + when genericMethod == QueryableMethodProvider.SingleOrDefaultWithPredicateMethodInfo: + shapedQueryExpression.ResultCardinality = ResultCardinality.SingleOrDefault; + return TranslateSingleOrDefault(shapedQueryExpression, GetLambdaExpressionFromArgument(1), methodCallExpression.Type, true); + + case nameof(Queryable.Skip) + when genericMethod == QueryableMethodProvider.SkipMethodInfo: return TranslateSkip(shapedQueryExpression, methodCallExpression.Arguments[1]); - case nameof(Queryable.SkipWhile): - return TranslateSkipWhile( - shapedQueryExpression, - methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + case nameof(Queryable.SkipWhile) + when genericMethod == QueryableMethodProvider.SkipWhileMethodInfo: + return TranslateSkipWhile(shapedQueryExpression, GetLambdaExpressionFromArgument(1)); + + case nameof(Queryable.Sum) + when QueryableMethodProvider.IsSumWithoutSelectorMethodInfo(method): + shapedQueryExpression.ResultCardinality = ResultCardinality.Single; + return TranslateSum(shapedQueryExpression, null, methodCallExpression.Type); - case nameof(Queryable.Sum): + case nameof(Queryable.Sum) + when QueryableMethodProvider.IsSumWithSelectorMethodInfo(method): shapedQueryExpression.ResultCardinality = ResultCardinality.Single; - return TranslateSum( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null, - methodCallExpression.Type); - - case nameof(Queryable.Take): + return TranslateSum(shapedQueryExpression, GetLambdaExpressionFromArgument(1), methodCallExpression.Type); + + case nameof(Queryable.Take) + when genericMethod == QueryableMethodProvider.TakeMethodInfo: return TranslateTake(shapedQueryExpression, methodCallExpression.Arguments[1]); - case nameof(Queryable.TakeWhile): - return TranslateTakeWhile( - shapedQueryExpression, - methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + case nameof(Queryable.TakeWhile) + when genericMethod == QueryableMethodProvider.TakeWhileMethodInfo: + return TranslateTakeWhile(shapedQueryExpression, GetLambdaExpressionFromArgument(1)); case nameof(Queryable.ThenBy) - when argumentCount == 2: - return TranslateThenBy( - shapedQueryExpression, - methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), - true); + when genericMethod == QueryableMethodProvider.ThenByMethodInfo: + return TranslateThenBy(shapedQueryExpression, GetLambdaExpressionFromArgument(1), true); case nameof(Queryable.ThenByDescending) - when argumentCount == 2: - return TranslateThenBy( - shapedQueryExpression, - methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), - false); + when genericMethod == QueryableMethodProvider.ThenByDescendingMethodInfo: + return TranslateThenBy(shapedQueryExpression, GetLambdaExpressionFromArgument(1), false); case nameof(Queryable.Union) - when argumentCount == 2: + when genericMethod == QueryableMethodProvider.UnionMethodInfo: { - var source2 = Visit(methodCallExpression.Arguments[1]); - if (source2 is ShapedQueryExpression innerShapedQueryExpression) + if (Visit(methodCallExpression.Arguments[1]) is ShapedQueryExpression innerShapedQueryExpression) { - return TranslateUnion( - shapedQueryExpression, - innerShapedQueryExpression); + return TranslateUnion(shapedQueryExpression, innerShapedQueryExpression); } - } - break; + } - case nameof(Queryable.Where): - return TranslateWhere( - shapedQueryExpression, - methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + case nameof(Queryable.Where) + when genericMethod == QueryableMethodProvider.WhereMethodInfo: + return TranslateWhere(shapedQueryExpression, GetLambdaExpressionFromArgument(1)); - case nameof(Queryable.Zip): - // Don't know - break; + LambdaExpression GetLambdaExpressionFromArgument(int argumentIndex) => methodCallExpression.Arguments[argumentIndex].UnwrapLambdaFromQuote(); } } } return _subquery ? (Expression)null - : throw new NotImplementedException("Unhandled method: " + methodCallExpression.Method.Name); + : throw new NotImplementedException("Unhandled method: " + method.Name); } private class EntityShaperNullableMarkingExpressionVisitor : ExpressionVisitor