Skip to content

Commit

Permalink
TryFindAverageMethod
Browse files Browse the repository at this point in the history
  • Loading branch information
StefH committed May 1, 2024
1 parent 98b1bae commit 02af592
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 18 deletions.
22 changes: 6 additions & 16 deletions src/System.Linq.Dynamic.Core/Parser/ExpressionParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2072,8 +2072,8 @@ private bool TryParseEnumerable(Expression instance, Type elementType, string me
_it = outerIt;
_parent = oldParent;

var typeToCheckForTypeOfString = type ?? instance.Type;
if (typeToCheckForTypeOfString == typeof(string) && _methodFinder.ContainsMethod(typeToCheckForTypeOfString, methodName, false, instance, ref args))
var theType = type ?? instance.Type;
if (theType == typeof(string) && _methodFinder.ContainsMethod(theType, methodName, false, instance, ref args))
{
// In case the type is a string, and does contain the methodName (like "IndexOf"), then return false to indicate that the methodName is not an Enumerable method.
expression = null;
Expand All @@ -2096,21 +2096,11 @@ private bool TryParseEnumerable(Expression instance, Type elementType, string me
callType = typeof(Queryable);
}

// #633 - For Average without any arguments, try to find the non-generic Average method for the supplied 'type'.
if (methodName == nameof(Enumerable.Average) && args.Length == 0)
// #633 - For Average without any arguments, try to find the non-generic Average method on the callType for the supplied parameter type.
if (methodName == nameof(Enumerable.Average) && args.Length == 0 && _methodFinder.TryFindAverageMethod(callType, theType, out var averageMethod))
{
var averageMethod = callType
.GetMethods()
.Where(m => m is { Name: nameof(Enumerable.Average), IsGenericMethodDefinition: false })
.SelectMany(m => m.GetParameters(), (m, p) => new { Method = m, Parameter = p })
.Where(x => x.Parameter.ParameterType == type)
.Select(x => x.Method)
.FirstOrDefault();
if (averageMethod != null)
{
expression = Expression.Call(null, averageMethod, new[] { instance });
return true;
}
expression = Expression.Call(null, averageMethod, new[] { instance });
return true;
}

Type[] typeArgs;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq.Dynamic.Core.Validation;
using System.Linq.Expressions;
using System.Reflection;
Expand Down Expand Up @@ -44,6 +45,19 @@ public MethodFinder(ParsingConfig parsingConfig, IExpressionHelper expressionHel
_expressionHelper = Check.NotNull(expressionHelper);
}

public bool TryFindAverageMethod(Type callType, Type parameterType, [NotNullWhen(true)] out MethodInfo? averageMethod)
{
averageMethod = callType
.GetMethods()
.Where(m => m is { Name: nameof(Enumerable.Average), IsGenericMethodDefinition: false })
.SelectMany(m => m.GetParameters(), (m, p) => new { Method = m, Parameter = p })
.Where(x => x.Parameter.ParameterType == parameterType)
.Select(x => x.Method)
.FirstOrDefault();

return averageMethod != null;
}

public void CheckAggregateMethodAndTryUpdateArgsToMatchMethodArgs(string methodName, ref Expression[] args)
{
if (methodName is nameof(IAggregateSignatures.Average) or nameof(IAggregateSignatures.Sum))
Expand Down
4 changes: 2 additions & 2 deletions src/System.Linq.Dynamic.Core/Util/QueryableMethodFinder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ public static MethodInfo GetGenericMethod(string name)
{
return typeof(Queryable).GetTypeInfo().GetDeclaredMethods(name).Single(mi => mi.IsGenericMethod);
}

public static MethodInfo GetMethod(string name, Type argumentType, Type returnType, int parameterCount = 0, Func<MethodInfo, bool>? predicate = null) =>
GetMethod(name, returnType, parameterCount, mi => mi.ToString().Contains(argumentType.ToString()) && ((predicate == null) || predicate(mi)));
GetMethod(name, returnType, parameterCount, mi => mi.ToString().Contains(argumentType.ToString()) && (predicate == null || predicate(mi)));

public static MethodInfo GetMethod(string name, Type returnType, int parameterCount = 0, Func<MethodInfo, bool>? predicate = null)
{
Expand Down

0 comments on commit 02af592

Please sign in to comment.