Skip to content

Commit

Permalink
Add nested window functions
Browse files Browse the repository at this point in the history
  • Loading branch information
virzak committed Feb 5, 2024
1 parent 424a8d1 commit f46ff96
Show file tree
Hide file tree
Showing 7 changed files with 426 additions and 150 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
namespace Zomp.EFCore.WindowFunctions.Query.Internal;

internal sealed class WindowFunctionInsideWhereDetector : ExpressionVisitor
/// <summary>
/// Creates Subquery for nested window functions and where clauses.
/// </summary>
public class WindowFunctionInsideWhereDetector : ExpressionVisitor
{
private const string Original = "Original";
private static readonly ConstructorInfo ConObj = typeof(object).GetConstructor([])
?? throw new InvalidOperationException("Can't be null");

Expand All @@ -10,115 +14,228 @@ internal sealed class WindowFunctionInsideWhereDetector : ExpressionVisitor
.GroupBy(mi => mi.Name)
.ToDictionary(e => e.Key, l => l.ToList());

/*
private static readonly MethodInfo Where = GetMethod(
nameof(Enumerable.Where),
1,
types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], typeof(bool)) });
*/
/// <summary>
/// Maps a where method call to window functions used inside it.
/// </summary>
private readonly Stack<Clause> clauseStack = new();

public IList<MethodCallExpression> WindowFunctionsCollection { get; } = new List<MethodCallExpression>();
private enum SubqueryType
{
Where,
Select,
OrderBy,
OrderByDescending,
ThenBy,
ThenByDescending,
}

/// <inheritdoc/>
protected override Expression VisitMethodCall(MethodCallExpression node)
{
var method = node.Method;
if (method.DeclaringType == typeof(Queryable))
var suspendWhereStack = clauseStack.TryPeek(out var top) && top.FirstArgument == node ? top : null;

if (suspendWhereStack is not null)
{
if (method.Name == nameof(Queryable.Where))
{
var wfd = new WindowFunctionDetectorInternal();
_ = wfd.Visit(node.Arguments[1]);
clauseStack.Pop();
}

var whereArg = node.Arguments[1];
var reduced = whereArg;
SubqueryType? subqueryType = node.Method.DeclaringType != typeof(Queryable) ? null
: node.Method.Name switch
{
nameof(Queryable.Where) => SubqueryType.Where,
nameof(Queryable.Select) => SubqueryType.Select,
nameof(Queryable.OrderBy) => SubqueryType.OrderBy,
nameof(Queryable.OrderByDescending) => SubqueryType.OrderByDescending,
nameof(Queryable.ThenBy) => SubqueryType.ThenBy,
nameof(Queryable.ThenByDescending) => SubqueryType.ThenByDescending,
_ => null,
};

while (reduced is UnaryExpression ue)
{
reduced = ue.Operand;
}
var wfNode = WindowFunctionsEvaluatableExpressionFilter.WindowFunctionMethods.Contains(node.Method, CompareNameAndDeclaringType.Default)
? node : null;

if (reduced is LambdaExpression whereLambda && wfd.WindowFunctionsCollection.Count > 0)
{
var whereParam = whereLambda.Parameters[0];
clauseStack.TryPeek(out var clause);

var typeArg = node.Arguments[0].Type.GenericTypeArguments[0];
var p = Expression.Parameter(typeArg, "o");
if (wfNode is not null)
{
_ = clause ?? throw new InvalidOperationException("Window Function outside of select or where?");

var wfToSelect = wfd.WindowFunctionsCollection.Select(e => e.ReplaceParameter(whereParam, p)).ToList();
clause.Add(wfNode, clause.WindowFunctionStack.Count);

var list = new Dictionary<Expression, Name_Type_And_Replacement>();
clause.WindowFunctionStack.Push(wfNode);
}

for (var i = 0; i < wfd.WindowFunctionsCollection.Count; ++i)
{
var item = wfd.WindowFunctionsCollection[i];
list[item] = new($"P{i}", item.Type, wfToSelect[i]);
}
if (subqueryType is not null)
{
clauseStack.Push(new(node));
}

var replacements = list.Values.ToList();
replacements.Insert(0, new("Original", typeArg, null!));
var @base = (MethodCallExpression)base.VisitMethodCall(node);

var anonType = CreateNewType(replacements);
if (wfNode is not null)
{
_ = clause ?? throw new InvalidOperationException("Window Function outside of select or where?");
clause.WindowFunctionStack.Pop();
}

var originalProperty = anonType.GetProperty("Original")
?? throw new InvalidOperationException("Can't be null, it was just created.");
var method = @base.Method;
if (subqueryType is not null)
{
var whereArg = @base.Arguments[1];
var reduced = whereArg;

var mbs = new List<MemberBinding>() { Expression.Bind(originalProperty, p) };
while (reduced is UnaryExpression ue)
{
reduced = ue.Operand;
}

foreach (var item in list)
{
////var item = zz.WindowFunctionsCollection[i];
////mbs.Add(Expression.Bind())
if (reduced is not LambdaExpression lambda)
{
throw new InvalidOperationException("Expression must be lambda");
}

var mi = anonType.GetMember(item.Value.Name)[0]
?? throw new InvalidOperationException($"{anonType} doesn't have property");
mbs.Add(Expression.Bind(mi, item.Value.Replacement));
}
var pop = clauseStack.Pop();
if (pop.WindowFunctions.Count > 0)
{
@base = subqueryType == SubqueryType.Select
? HandleSelectLambda(@base, lambda, pop)
: HandleLambda(@base, lambda, pop);
}
}

var @new = Expression.New(anonType);
var memberInit = Expression.MemberInit(@new, mbs);
var lambda = Expression.Lambda(memberInit, p);
if (suspendWhereStack is not null)
{
clauseStack.Push(suspendWhereStack);
}

node = Expression.Call(
null,
QueryableMethods.Select.MakeGenericMethod(typeArg, anonType),
node.Arguments[0],
lambda);
return @base;
}

var asSubQueryMethod = WindowFunctionsEvaluatableExpressionFilter.AsSubQueryMethod.MakeGenericMethod(anonType);
node = Expression.Call(null, asSubQueryMethod, node);
private static MethodCallExpression HandleSelectLambda(MethodCallExpression node, LambdaExpression originalLambda, Clause queryableMethod)
{
if (queryableMethod.WindowFunctions.Count < 2)
{
return node;
}

p = Expression.Parameter(anonType, "w");
var wfr = new WindowFunctionRewriter(list, whereParam, p);
var body = wfr.Visit(whereLambda.Body);
var newWhere = Expression.Lambda(body, p);
var expression = node.Arguments[0];
var rewrittenLambda = BuildSubqueries(queryableMethod.WindowFunctions, originalLambda, ref expression);

var newWhereMethod = QueryableMethods.Where.MakeGenericMethod(anonType);
node = Expression.Call(null, newWhereMethod, node, newWhere);
// Add .Select<MyAnon>(s => RowNumber(Over().OrderBy(s.P0)))
var newSelectMethod = QueryableMethods.Select.MakeGenericMethod(expression.Type.GenericTypeArguments[0], node.Method.ReturnType.GenericTypeArguments[0]);
var newSelect = Expression.Call(null, newSelectMethod, expression, rewrittenLambda);

// go back to original
p = Expression.Parameter(anonType, "b");
var toOriginal = Expression.Property(p, "Original");
var trailingSelect = Expression.Lambda(toOriginal, p);
var toOriginalMethod = QueryableMethods.Select.MakeGenericMethod(anonType, typeArg);
return newSelect;
}

node = Expression.Call(null, toOriginalMethod, node, trailingSelect);
return node;
}
private static MethodCallExpression HandleLambda(MethodCallExpression node, LambdaExpression originalLambda, Clause queryableMethod)
{
var methodInfo = queryableMethod.Method.Method;
var isWhere = methodInfo.Name == nameof(Queryable.Where);

////return call;
////return Expression.Call(node.Arguments[0], null!);
}
if (queryableMethod.WindowFunctions.Count < 2 && !isWhere)
{
return node;
}

if (WindowFunctionsEvaluatableExpressionFilter.WindowFunctionMethods.Contains(node.Method, CompareNameAndDeclaringType.Default))
var expression = node.Arguments[0];
var rewrittenLambda = BuildSubqueries(queryableMethod.WindowFunctions, originalLambda, ref expression, isWhere);

var anonType = expression.Type.GenericTypeArguments[0];

var newMethod = methodInfo.GetGenericMethodDefinition().MakeGenericMethod([anonType, .. methodInfo.GetGenericArguments()[1..]]);

// Add .Where<MyAnon>(w => (w.P0 == 1))
var newCall = Expression.Call(null, newMethod, expression, rewrittenLambda);

// go back to original
var b = Expression.Parameter(anonType, "b");
var toOriginal = Expression.Property(b, Original);
var trailingSelect = Expression.Lambda(toOriginal, b);
var originalTypeArgument = node.Arguments[0].Type.GenericTypeArguments[0];
var toOriginalMethod = QueryableMethods.Select.MakeGenericMethod(anonType, originalTypeArgument);

// Add .Select<MyAnon>(b => b.Original)
node = Expression.Call(null, toOriginalMethod, newCall, trailingSelect);
return node;
}

private static LambdaExpression BuildSubqueries(
IList<IList<MethodCallExpression>> windowFunctions,
LambdaExpression lambda,
ref Expression expression,
bool isWhere = false)
{
var parameter = lambda.Parameters[0];
var newParameter = parameter;
var lastLevel = isWhere ? 0 : 1;
var subqueryList = windowFunctions[^1];
IDictionary<MethodCallExpression, Name_Type_And_Replacement> replacementMap;
var level = windowFunctions.Count - 1;
WindowFunctionRewriter wfr;

while (true)
{
WindowFunctionsCollection.Add(node);
var typeArg = expression.Type.GenericTypeArguments[0];

var replacements = new List<Name_Type_And_Replacement>();
for (var i = 0; i < subqueryList.Count; ++i)
{
var item = subqueryList[i];
replacements.Add(new($"P{i}", item.Type, item));
}

replacements.Insert(0, new(Original, parameter.Type, null!));

var anonType = CreateNewType(replacements);
var originalProperty = anonType.GetProperty(Original)
?? throw new InvalidOperationException("Can't be null, it was just created.");

Expression originalValue = parameter == newParameter ? parameter : Expression.Property(newParameter, Original);
var mbs = new List<MemberBinding>() { Expression.Bind(originalProperty, originalValue) };

foreach (var item in replacements[1..])
{
var mi = anonType.GetMember(item.Name)[0]
?? throw new InvalidOperationException($"{anonType} doesn't have property");
mbs.Add(Expression.Bind(mi, item.Replacement));
}

var @new = Expression.New(anonType);
var memberInit = Expression.MemberInit(@new, mbs);

var innerSelectLambda = Expression.Lambda(memberInit, newParameter);

// o => .Select<MyAnon>(o => new MyAnon() {Original = o, P0 = __Functions_0.RowNumber(__Over_1.OrderBy(o.Id))})
var node = Expression.Call(
null,
QueryableMethods.Select.MakeGenericMethod(typeArg, anonType),
expression,
innerSelectLambda);

var asSubQueryMethod = WindowFunctionsEvaluatableExpressionFilter.AsSubQueryMethod.MakeGenericMethod(anonType);

// Add .AsSubQuery<MyAnon>()
expression = Expression.Call(null, asSubQueryMethod, node);
replacementMap = windowFunctions[level].Zip(replacements[1..])
.ToDictionary(z => z.First, z => z.Second);
newParameter = Expression.Parameter(expression.Type.GenericTypeArguments[0], $"{(level > 0 ? "l" + (level - 1) : "w")}");

wfr = new(replacementMap, parameter, newParameter);

if (--level < 0)
{
break;
}

var replacing = windowFunctions[level];
subqueryList = replacing.Select(z => (MethodCallExpression)wfr.Visit(z)).ToList();
}

var retVal = base.VisitMethodCall(node);
var newBody = wfr.Visit(lambda.Body);
var rewrittenLambda = Expression.Lambda(newBody, newParameter);

return retVal;
return rewrittenLambda;
}

///// From : https://www.codeproject.com/Articles/121568/Dynamic-Type-Using-Reflection-Emit
Expand Down Expand Up @@ -197,21 +314,17 @@ private static MethodInfo GetMethod(string name, int genericParameterCount, Func

private sealed record Name_Type_And_Replacement(string Name, Type Type, Expression Replacement);

private sealed class WindowFunctionRewriter(Dictionary<Expression, Name_Type_And_Replacement> map, ParameterExpression source, ParameterExpression target) : ExpressionVisitor
private sealed class WindowFunctionRewriter(IDictionary<MethodCallExpression, Name_Type_And_Replacement> map, ParameterExpression source, ParameterExpression target) : ExpressionVisitor
{
private readonly Dictionary<Expression, Name_Type_And_Replacement> map = map;
private readonly ParameterExpression source = source;
private readonly Expression target = target;

protected override Expression VisitMethodCall(MethodCallExpression node)
{
if (map.TryGetValue(node, out var rec))
if (!map.TryGetValue(node, out var rec))
{
var propExpression = Expression.Property(target, rec.Name);
return propExpression;
return base.VisitMethodCall(node);
}

return base.VisitMethodCall(node);
var propExpression = Expression.Property(target, rec.Name);
return propExpression;
}

protected override Expression VisitParameter(ParameterExpression node)
Expand All @@ -221,8 +334,31 @@ protected override Expression VisitParameter(ParameterExpression node)
return base.VisitParameter(node);
}

var newProperty = Expression.Property(target, "Original");
var newProperty = Expression.Property(target, Original);
return newProperty;
}
}

/// <summary>
/// Represents a method which could hold a Window function, such as .Select, .Where, .OdrerBy.
/// </summary>
/// <param name="Method">Method call expression.</param>
private sealed record Clause(MethodCallExpression Method)
{
public Stack<MethodCallExpression> WindowFunctionStack { get; } = new();

public IList<IList<MethodCallExpression>> WindowFunctions { get; } = [];

public Expression FirstArgument => Method.Arguments[0];

public void Add(MethodCallExpression call, int level)
{
while (WindowFunctions.Count == level)
{
WindowFunctions.Add([]);
}

WindowFunctions[level].Add(call);
}
}
}
2 changes: 1 addition & 1 deletion tests/Zomp.EFCore.WindowFunctions.Npgsql.Tests/Partials.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ public partial class CountTests(ITestOutputHelper output) : CountTests<long>(out
public partial class AnalyticTests(ITestOutputHelper output) : TestBase(output) { }

[Collection(nameof(NpgsqlCollection))]
public partial class WhereTests(ITestOutputHelper output) : TestBase(output) { }
public partial class SubQueryTests(ITestOutputHelper output) : TestBase(output) { }
#pragma warning restore SA1402 // File may only contain a single type
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ public partial class CountTests(ITestOutputHelper output) : CountTests<int>(outp
public partial class AnalyticTests(ITestOutputHelper output) : TestBase(output) { }

[Collection(nameof(SqlServerCollection))]
public partial class WhereTests(ITestOutputHelper output) : TestBase(output) { }
public partial class SubQueryTests(ITestOutputHelper output) : TestBase(output) { }
#pragma warning restore SA1402 // File may only contain a single type
2 changes: 1 addition & 1 deletion tests/Zomp.EFCore.WindowFunctions.Sqlite.Tests/Partials.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ public partial class CountTests(ITestOutputHelper output) : CountTests<int>(outp
public partial class AnalyticTests(ITestOutputHelper output) : TestBase(output) { }

[Collection(nameof(SqliteCollection))]
public partial class WhereTests(ITestOutputHelper output) : TestBase(output) { }
public partial class SubQueryTests(ITestOutputHelper output) : TestBase(output) { }
#pragma warning restore SA1402 // File may only contain a single type
Loading

0 comments on commit f46ff96

Please sign in to comment.