Skip to content

Commit

Permalink
Push join argument down into subquery if it contains a Window Function
Browse files Browse the repository at this point in the history
  • Loading branch information
virzak committed Jan 29, 2024
1 parent ce835fb commit 1e99531
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 40 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
namespace Zomp.EFCore.WindowFunctions.Query.Internal;

internal sealed class CompareNameAndDeclaringType : IEqualityComparer<MethodInfo>
{
public static CompareNameAndDeclaringType Default { get; } = new();

public bool Equals(MethodInfo? x, MethodInfo? y)
{
if (x is null || y is null)
{
return x is null && y is null;
}

return x.Name.Equals(y.Name, StringComparison.Ordinal) && x.DeclaringType == y.DeclaringType;
}

public int GetHashCode(MethodInfo method)
{
return HashCode.Combine(method.Name.GetHashCode(StringComparison.Ordinal), method.DeclaringType?.GetHashCode());
}
}
76 changes: 76 additions & 0 deletions src/Zomp.EFCore.WindowFunctions/Query/Internal/JoinDetector.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
namespace Zomp.EFCore.WindowFunctions.Query.Internal;

internal sealed class JoinDetector : ExpressionVisitor
{
private static readonly Dictionary<string, List<MethodInfo>> QueryableMethodGroups = typeof(Enumerable)
.GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.GroupBy(mi => mi.Name)
.ToDictionary(e => e.Key, l => l.ToList());

/*
private static readonly MethodInfo Where = GetMethod(
nameof(Enumerable.Where),
1,
types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], typeof(bool)) });
*/

public IList<MethodCallExpression> WindowFunctionsCollection { get; } = new List<MethodCallExpression>();

/// <inheritdoc/>
protected override Expression VisitMethodCall(MethodCallExpression node)
{
var method = node.Method;
if (method.DeclaringType == typeof(Queryable))
{
if (method.Name == nameof(Queryable.Join))
{
var wfd = new WindowFunctionDetectorInternal();
var joinArg = node.Arguments[1];

_ = wfd.Visit(joinArg);

if (wfd.WindowFunctionsCollection.Count == 0)
{
return base.VisitMethodCall(node);
}

var @base = (MethodCallExpression)base.VisitMethodCall(node);

var type = QueryableMethods.Join.MakeGenericMethod(method.GetGenericArguments());

var args = @base.Arguments;

var argForSubQuery = args[1];

var asSubQueryMethod = WindowFunctionsEvaluatableExpressionFilter.AsSubQueryMethod.MakeGenericMethod(argForSubQuery.Type.GetGenericArguments()[0]);
var withSubquery = Expression.Call(null, asSubQueryMethod, argForSubQuery);

var joinWithSubquery = Expression.Call(
null,
method,
args[0],
withSubquery,
args[2],
args[3],
args[4]);
return joinWithSubquery;
}
}

if (WindowFunctionsEvaluatableExpressionFilter.WindowFunctionMethods.Contains(node.Method, CompareNameAndDeclaringType.Default))
{
WindowFunctionsCollection.Add(node);
}

var retVal = base.VisitMethodCall(node);

return retVal;
}

private static MethodInfo GetMethod(string name, int genericParameterCount, Func<Type[], Type[]> parameterGenerator)
=> QueryableMethodGroups[name].Single(
mi => ((genericParameterCount == 0 && !mi.IsGenericMethod)
|| (mi.IsGenericMethod && mi.GetGenericArguments().Length == genericParameterCount))
&& mi.GetParameters().Select(e => e.ParameterType).SequenceEqual(
parameterGenerator(mi.IsGenericMethod ? mi.GetGenericArguments() : [])));
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
namespace Zomp.EFCore.WindowFunctions.Query.Internal;

internal sealed class WindowFunctionDetectorInternal : ExpressionVisitor
{
public IList<MethodCallExpression> WindowFunctionsCollection { get; } = [];

protected override Expression VisitMethodCall(MethodCallExpression node)
{
if (WindowFunctionsEvaluatableExpressionFilter.WindowFunctionMethods.Contains(node.Method, CompareNameAndDeclaringType.Default))
{
WindowFunctionsCollection.Add(node);
}

return base.VisitMethodCall(node);
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
namespace Zomp.EFCore.WindowFunctions.Query.Internal;

internal sealed class WindowFunctionDetector : ExpressionVisitor
internal sealed class WindowFunctionInsideWhereDetector : ExpressionVisitor
{
private static readonly ConstructorInfo ConObj = typeof(object).GetConstructor([])
?? throw new InvalidOperationException("Can't be null");
Expand Down Expand Up @@ -225,39 +225,4 @@ protected override Expression VisitParameter(ParameterExpression node)
return newProperty;
}
}

private sealed class WindowFunctionDetectorInternal : ExpressionVisitor
{
public IList<MethodCallExpression> WindowFunctionsCollection { get; } = [];

protected override Expression VisitMethodCall(MethodCallExpression node)
{
if (WindowFunctionsEvaluatableExpressionFilter.WindowFunctionMethods.Contains(node.Method, CompareNameAndDeclaringType.Default))
{
WindowFunctionsCollection.Add(node);
}

return base.VisitMethodCall(node);
}
}

private sealed class CompareNameAndDeclaringType : IEqualityComparer<MethodInfo>
{
public static CompareNameAndDeclaringType Default { get; } = new();

public bool Equals(MethodInfo? x, MethodInfo? y)
{
if (x is null || y is null)
{
return x is null && y is null;
}

return x.Name.Equals(y.Name, StringComparison.Ordinal) && x.DeclaringType == y.DeclaringType;
}

public int GetHashCode(MethodInfo method)
{
return HashCode.Combine(method.Name.GetHashCode(StringComparison.Ordinal), method.DeclaringType?.GetHashCode());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,21 @@ public class WindowFunctionsRelationalQueryTranslationPreprocessor(QueryTranslat
/// <inheritdoc/>
public override Expression Process(Expression query)
{
var @base = base.Process(query);
var wfd = new WindowFunctionDetector();
var rewritten = wfd.Visit(@base);
return rewritten;
query = new InvocationExpressionRemovingExpressionVisitor().Visit(query);
query = NormalizeQueryableMethod(query);
query = new CallForwardingExpressionVisitor().Visit(query);
query = new NullCheckRemovingExpressionVisitor().Visit(query);
query = new SubqueryMemberPushdownExpressionVisitor(QueryCompilationContext.Model).Visit(query);
query = new JoinDetector().Visit(query);
query = new NavigationExpandingExpressionVisitor(
this,
QueryCompilationContext,
Dependencies.EvaluatableExpressionFilter,
Dependencies.NavigationExpansionExtensibilityHelper)
.Expand(query);
query = new QueryOptimizingExpressionVisitor().Visit(query);
query = new NullCheckRemovingExpressionVisitor().Visit(query);
query = new WindowFunctionInsideWhereDetector().Visit(query);
return query;
}
}
16 changes: 16 additions & 0 deletions tests/Zomp.EFCore.WindowFunctions.Testing/WhereTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,20 @@ public void MaxWithWhere()

Assert.Equal(expected, result.Single(), TestRowEqualityComparer.Default);
}

[Fact]
public void Join()
{
var query = DbContext.TestRows.Join(
DbContext.TestRows.Select(subRow => new
{
subRow.Id,
RowNumber = EF.Functions.RowNumber(EF.Functions.Over().OrderBy(subRow.Date).PartitionBy(subRow.Col1)),
}),
l => l.Id,
r => r.Id,
(l, r) => new { r.Id, r.RowNumber });

var queryStr = query.ToQueryString();
}
}

0 comments on commit 1e99531

Please sign in to comment.