From 743053361214bff570f3fdfffbd42ec2b21db76b Mon Sep 17 00:00:00 2001 From: Stef Heyenrath Date: Thu, 11 Apr 2024 20:27:39 +0200 Subject: [PATCH] Fix Aggregate methods Average and Sum --- .../Parser/ExpressionParser.cs | 5 +- .../Parser/SupportedMethods/MethodFinder.cs | 58 +++++++++++++++---- .../DynamicExpressionParserTests.cs | 7 ++- 3 files changed, 56 insertions(+), 14 deletions(-) diff --git a/src/System.Linq.Dynamic.Core/Parser/ExpressionParser.cs b/src/System.Linq.Dynamic.Core/Parser/ExpressionParser.cs index e19a5f39..b7c9c304 100644 --- a/src/System.Linq.Dynamic.Core/Parser/ExpressionParser.cs +++ b/src/System.Linq.Dynamic.Core/Parser/ExpressionParser.cs @@ -2064,6 +2064,9 @@ private Expression ParseEnumerable(Expression instance, Type elementType, string return Expression.Call(instance, dictionaryMethod, args); } + // #794 - Check if the method is an aggregate (Average or Sum) method and try to update the arguments to match the method arguments + _methodFinder.CheckAggregateMethodAndTryUpdateArgsToMatchMethodArgs(methodName, ref args); + var callType = typeof(Enumerable); if (type != null && TypeHelper.FindGenericType(typeof(IQueryable<>), type) != null && _methodFinder.ContainsMethod(type, methodName)) { @@ -2081,7 +2084,7 @@ private Expression ParseEnumerable(Expression instance, Type elementType, string typeArgs = new[] { ResolveTypeFromArgumentExpression(methodName, args[0]) }; args = new Expression[0]; } - else if (new[] { "Min", "Max", "Select", "OrderBy", "OrderByDescending", "ThenBy", "ThenByDescending", "GroupBy" }.Contains(methodName)) + else if (new[] { "Max", "Min", "Select", "OrderBy", "OrderByDescending", "ThenBy", "ThenByDescending", "GroupBy" }.Contains(methodName)) { if (args.Length == 2) { diff --git a/src/System.Linq.Dynamic.Core/Parser/SupportedMethods/MethodFinder.cs b/src/System.Linq.Dynamic.Core/Parser/SupportedMethods/MethodFinder.cs index 98afb028..0b0b2331 100644 --- a/src/System.Linq.Dynamic.Core/Parser/SupportedMethods/MethodFinder.cs +++ b/src/System.Linq.Dynamic.Core/Parser/SupportedMethods/MethodFinder.cs @@ -10,19 +10,55 @@ internal class MethodFinder private readonly ParsingConfig _parsingConfig; private readonly IExpressionHelper _expressionHelper; + /// + /// #794 + /// + private interface IAggregateSignatures + { + void Average(decimal? selector); + void Average(decimal selector); + void Average(double? selector); + void Average(double selector); + void Average(float? selector); + void Average(float selector); + void Average(int? selector); + void Average(int selector); + void Average(long? selector); + void Average(long selector); + + void Sum(decimal? selector); + void Sum(decimal selector); + void Sum(double? selector); + void Sum(double selector); + void Sum(float? selector); + void Sum(float selector); + void Sum(int? selector); + void Sum(int selector); + void Sum(long? selector); + void Sum(long selector); + } + public MethodFinder(ParsingConfig parsingConfig, IExpressionHelper expressionHelper) { _parsingConfig = Check.NotNull(parsingConfig); _expressionHelper = Check.NotNull(expressionHelper); } + public void CheckAggregateMethodAndTryUpdateArgsToMatchMethodArgs(string methodName, ref Expression[] args) + { + if (methodName is nameof(IAggregateSignatures.Average) or nameof(IAggregateSignatures.Sum)) + { + ContainsMethod(typeof(IAggregateSignatures), methodName, false, null, ref args); + } + } + public bool ContainsMethod(Type type, string methodName, bool staticAccess = true) { Check.NotNull(type); #if !(NETFX_CORE || WINDOWS_APP || UAP10_0 || NETSTANDARD) - var flags = BindingFlags.Public | BindingFlags.DeclaredOnly | (staticAccess ? BindingFlags.Static : BindingFlags.Instance); - return type.FindMembers(MemberTypes.Method, flags, Type.FilterNameIgnoreCase, methodName).Any(); + var flags = BindingFlags.Public | BindingFlags.DeclaredOnly | (staticAccess ? BindingFlags.Static : BindingFlags.Instance); + return type.FindMembers(MemberTypes.Method, flags, Type.FilterNameIgnoreCase, methodName).Any(); #else return type.GetTypeInfo().DeclaredMethods.Any(m => (m.IsStatic || !staticAccess) && m.Name.Equals(methodName, StringComparison.OrdinalIgnoreCase)); #endif @@ -40,17 +76,17 @@ public bool ContainsMethod(Type type, string methodName, bool staticAccess, Expr public int FindMethod(Type? type, string methodName, bool staticAccess, ref Expression? instance, ref Expression[] args, out MethodBase? method) { -#if !(NETFX_CORE || WINDOWS_APP || UAP10_0 || NETSTANDARD) - BindingFlags flags = BindingFlags.Public | BindingFlags.DeclaredOnly | (staticAccess ? BindingFlags.Static : BindingFlags.Instance); - foreach (Type t in SelfAndBaseTypes(type)) +#if !(NETFX_CORE || WINDOWS_APP || UAP10_0 || NETSTANDARD) + BindingFlags flags = BindingFlags.Public | BindingFlags.DeclaredOnly | (staticAccess ? BindingFlags.Static : BindingFlags.Instance); + foreach (Type t in SelfAndBaseTypes(type)) + { + MemberInfo[] members = t.FindMembers(MemberTypes.Method, flags, Type.FilterNameIgnoreCase, methodName); + int count = FindBestMethodBasedOnArguments(members.Cast(), ref args, out method); + if (count != 0) { - MemberInfo[] members = t.FindMembers(MemberTypes.Method, flags, Type.FilterNameIgnoreCase, methodName); - int count = FindBestMethodBasedOnArguments(members.Cast(), ref args, out method); - if (count != 0) - { - return count; - } + return count; } + } #else foreach (Type t in SelfAndBaseTypes(type)) { diff --git a/test/System.Linq.Dynamic.Core.Tests/DynamicExpressionParserTests.cs b/test/System.Linq.Dynamic.Core.Tests/DynamicExpressionParserTests.cs index 5553f1a0..3b7d1e12 100644 --- a/test/System.Linq.Dynamic.Core.Tests/DynamicExpressionParserTests.cs +++ b/test/System.Linq.Dynamic.Core.Tests/DynamicExpressionParserTests.cs @@ -265,12 +265,15 @@ internal class TestClass794 [InlineData("Sum")] public void DynamicExpressionParser_ParseLambda_Aggregate(string operation) { - foreach (var propertyInfo in typeof(TestClass794).GetProperties().Where(p => p.Name == "ByteValue")) + foreach (var propertyInfo in typeof(TestClass794).GetProperties()) { var expression = $"{operation}({propertyInfo.Name})"; // e.g., "Sum(ByteValue)" - // Act + // Act IEnumerable DynamicExpressionParser.ParseLambda(itType: typeof(IEnumerable), resultType: typeof(double?), expression: expression); + + // Act IEnumerable + DynamicExpressionParser.ParseLambda(itType: typeof(IQueryable), resultType: typeof(double?), expression: expression); } }