Skip to content

Commit

Permalink
Fix Aggregate methods Average and Sum
Browse files Browse the repository at this point in the history
  • Loading branch information
StefH committed Apr 11, 2024
1 parent 8dff954 commit 7430533
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 14 deletions.
5 changes: 4 additions & 1 deletion src/System.Linq.Dynamic.Core/Parser/ExpressionParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2064,6 +2064,9 @@ private Expression ParseEnumerable(Expression instance, Type elementType, string
return Expression.Call(instance, dictionaryMethod, args);
}

// #794 - Check if the method is an aggregate (Average or Sum) method and try to update the arguments to match the method arguments
_methodFinder.CheckAggregateMethodAndTryUpdateArgsToMatchMethodArgs(methodName, ref args);

var callType = typeof(Enumerable);
if (type != null && TypeHelper.FindGenericType(typeof(IQueryable<>), type) != null && _methodFinder.ContainsMethod(type, methodName))
{
Expand All @@ -2081,7 +2084,7 @@ private Expression ParseEnumerable(Expression instance, Type elementType, string
typeArgs = new[] { ResolveTypeFromArgumentExpression(methodName, args[0]) };
args = new Expression[0];
}
else if (new[] { "Min", "Max", "Select", "OrderBy", "OrderByDescending", "ThenBy", "ThenByDescending", "GroupBy" }.Contains(methodName))
else if (new[] { "Max", "Min", "Select", "OrderBy", "OrderByDescending", "ThenBy", "ThenByDescending", "GroupBy" }.Contains(methodName))
{
if (args.Length == 2)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,55 @@ internal class MethodFinder
private readonly ParsingConfig _parsingConfig;
private readonly IExpressionHelper _expressionHelper;

/// <summary>
/// #794
/// </summary>
private interface IAggregateSignatures
{
void Average(decimal? selector);
void Average(decimal selector);
void Average(double? selector);
void Average(double selector);
void Average(float? selector);
void Average(float selector);
void Average(int? selector);
void Average(int selector);
void Average(long? selector);
void Average(long selector);

void Sum(decimal? selector);
void Sum(decimal selector);
void Sum(double? selector);
void Sum(double selector);
void Sum(float? selector);
void Sum(float selector);
void Sum(int? selector);
void Sum(int selector);
void Sum(long? selector);
void Sum(long selector);
}

public MethodFinder(ParsingConfig parsingConfig, IExpressionHelper expressionHelper)
{
_parsingConfig = Check.NotNull(parsingConfig);
_expressionHelper = Check.NotNull(expressionHelper);
}

public void CheckAggregateMethodAndTryUpdateArgsToMatchMethodArgs(string methodName, ref Expression[] args)
{
if (methodName is nameof(IAggregateSignatures.Average) or nameof(IAggregateSignatures.Sum))
{
ContainsMethod(typeof(IAggregateSignatures), methodName, false, null, ref args);
}
}

public bool ContainsMethod(Type type, string methodName, bool staticAccess = true)
{
Check.NotNull(type);

#if !(NETFX_CORE || WINDOWS_APP || UAP10_0 || NETSTANDARD)
var flags = BindingFlags.Public | BindingFlags.DeclaredOnly | (staticAccess ? BindingFlags.Static : BindingFlags.Instance);
return type.FindMembers(MemberTypes.Method, flags, Type.FilterNameIgnoreCase, methodName).Any();
var flags = BindingFlags.Public | BindingFlags.DeclaredOnly | (staticAccess ? BindingFlags.Static : BindingFlags.Instance);
return type.FindMembers(MemberTypes.Method, flags, Type.FilterNameIgnoreCase, methodName).Any();
#else
return type.GetTypeInfo().DeclaredMethods.Any(m => (m.IsStatic || !staticAccess) && m.Name.Equals(methodName, StringComparison.OrdinalIgnoreCase));
#endif
Expand All @@ -40,17 +76,17 @@ public bool ContainsMethod(Type type, string methodName, bool staticAccess, Expr

public int FindMethod(Type? type, string methodName, bool staticAccess, ref Expression? instance, ref Expression[] args, out MethodBase? method)
{
#if !(NETFX_CORE || WINDOWS_APP || UAP10_0 || NETSTANDARD)
BindingFlags flags = BindingFlags.Public | BindingFlags.DeclaredOnly | (staticAccess ? BindingFlags.Static : BindingFlags.Instance);
foreach (Type t in SelfAndBaseTypes(type))
#if !(NETFX_CORE || WINDOWS_APP || UAP10_0 || NETSTANDARD)
BindingFlags flags = BindingFlags.Public | BindingFlags.DeclaredOnly | (staticAccess ? BindingFlags.Static : BindingFlags.Instance);
foreach (Type t in SelfAndBaseTypes(type))
{
MemberInfo[] members = t.FindMembers(MemberTypes.Method, flags, Type.FilterNameIgnoreCase, methodName);
int count = FindBestMethodBasedOnArguments(members.Cast<MethodBase>(), ref args, out method);
if (count != 0)
{
MemberInfo[] members = t.FindMembers(MemberTypes.Method, flags, Type.FilterNameIgnoreCase, methodName);
int count = FindBestMethodBasedOnArguments(members.Cast<MethodBase>(), ref args, out method);
if (count != 0)
{
return count;
}
return count;
}
}
#else
foreach (Type t in SelfAndBaseTypes(type))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,12 +265,15 @@ internal class TestClass794
[InlineData("Sum")]
public void DynamicExpressionParser_ParseLambda_Aggregate(string operation)
{
foreach (var propertyInfo in typeof(TestClass794).GetProperties().Where(p => p.Name == "ByteValue"))
foreach (var propertyInfo in typeof(TestClass794).GetProperties())
{
var expression = $"{operation}({propertyInfo.Name})"; // e.g., "Sum(ByteValue)"

// Act
// Act IEnumerable
DynamicExpressionParser.ParseLambda(itType: typeof(IEnumerable<TestClass794>), resultType: typeof(double?), expression: expression);

// Act IEnumerable
DynamicExpressionParser.ParseLambda(itType: typeof(IQueryable<TestClass794>), resultType: typeof(double?), expression: expression);
}
}

Expand Down

0 comments on commit 7430533

Please sign in to comment.