Skip to content

Commit

Permalink
Booyah! First successful usage of the new Linq parsing w/ WHERE clauses!
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremydmiller committed Jul 12, 2023
1 parent afe3807 commit 2866ce8
Show file tree
Hide file tree
Showing 11 changed files with 293 additions and 54 deletions.
2 changes: 1 addition & 1 deletion src/LinqTests/playing.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public async Task try_linq()

var targets = await theSession.Query<Target>()
.OrderBy(x => x.Double)
.Where(x => x.Number == 5)
.Where(x => x.Number == 6)
.Take(5)
.ToListAsync();

Expand Down
2 changes: 1 addition & 1 deletion src/Marten/Linq/New/CollectionUsage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public NewStatement BuildStatement(DocumentCollection collection)
}

// TODO -- watch for "ANDS"
var parser = new WhereClauseParser(collection, statement);
var parser = new NewWhereClauseParser(collection, statement);
foreach (var expression in Wheres)
{
parser.Visit(expression);
Expand Down
14 changes: 14 additions & 0 deletions src/Marten/Linq/New/DocumentCollection.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using JasperFx.Core;
using JasperFx.Core.Reflection;
Expand Down Expand Up @@ -37,6 +39,18 @@ public IQueryableMember FindMember(MemberInfo member)
return m;
}

public IQueryableMember MemberFor(Expression expression)
{
// TODO -- NEEDS TO GET MORE COMPLICATED LATER

var members = FindMembers.Determine(expression);

if (members.Length != 1)
throw new NotImplementedException("Hey, deep accessors aren't yet supported in the new Linq provider");

return FindMember(members.Single());
}

private IQueryableMember createQuerableMember(MemberInfo member)
{
// TODO -- this will need to get MUCH fancier later
Expand Down
7 changes: 6 additions & 1 deletion src/Marten/Linq/New/IQueryableCollection.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Linq.Expressions;
using System.Reflection;
using Marten.Linq.New.Members;

Expand All @@ -9,4 +10,8 @@ public interface IQueryableCollection
Type ElementType { get; }
IQueryableMember FindMember(MemberInfo member);

}
// Assume there's a separate member for the usage of a member with methods
// i.e., StringMember.ToLower()
// Dictionary fields will create a separate "dictionary value field"
IQueryableMember MemberFor(Expression expression);
}
10 changes: 9 additions & 1 deletion src/Marten/Linq/New/Members/IQueryableMember.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
using System;
using System.Linq.Expressions;
using System.Reflection;
using Marten.Linq.New.Operators;
using Weasel.Postgresql.SqlGeneration;

namespace Marten.Linq.New.Members;

public interface IQueryableMember

public interface IQueryableMember : ISqlFragment
{
MemberInfo Member { get; }
Type MemberType { get; }
Expand All @@ -26,3 +29,8 @@ public interface IQueryableMember
/// <returns></returns>
string BuildOrderingExpression(NewOrdering ordering);
}

public interface IComparableMember
{
ISqlFragment CreateComparison(string op, ConstantExpression constant);
}
19 changes: 17 additions & 2 deletions src/Marten/Linq/New/Members/NumericMember.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
using System;
using System.Linq.Expressions;
using System.Reflection;
using Weasel.Core;
using Weasel.Postgresql;
using Weasel.Postgresql.SqlGeneration;

namespace Marten.Linq.New.Members;

internal class NumericMember: QueryableMember
internal class NumericMember: QueryableMember, IComparableMember
{
public NumericMember(string parentLocator, Casing casing, MemberInfo member) : base(parentLocator, casing, member)
{
Expand All @@ -15,4 +18,16 @@ public NumericMember(string parentLocator, Casing casing, MemberInfo member) : b
TypedLocator = $"CAST({RawLocator} as {pgType})";

}
}

public ISqlFragment CreateComparison(string op, ConstantExpression constant)
{
if (constant.Value == null)
{
throw new NotImplementedException("Not yet");
//return op == "=" ? new IsNullFilter(this) : new IsNotNullFilter(this);
}

var def = new CommandParameter(constant);
return new ComparisonFilter(this, def, op);
}
}
13 changes: 12 additions & 1 deletion src/Marten/Linq/New/Members/QueryableMember.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
using Marten.Linq.New.Operators;
using Marten.Util;
using Remotion.Linq.Clauses;
using Weasel.Postgresql;
using Weasel.Postgresql.SqlGeneration;

namespace Marten.Linq.New.Members;

Expand Down Expand Up @@ -40,4 +42,13 @@ public virtual string BuildOrderingExpression(NewOrdering ordering)
}


}
void ISqlFragment.Apply(CommandBuilder builder)
{
builder.Append(TypedLocator);
}

bool ISqlFragment.Contains(string sqlText)
{
return false;
}
}
227 changes: 227 additions & 0 deletions src/Marten/Linq/New/NewWhereClauseParser.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using Marten.Exceptions;
using Marten.Linq.Fields;
using Marten.Linq.New.Members;
using Remotion.Linq.Clauses.Expressions;
using Remotion.Linq.Parsing;
using Weasel.Postgresql.SqlGeneration;

namespace Marten.Linq.New;

public partial class NewWhereClauseParser : RelinqExpressionVisitor
{
private static readonly Dictionary<ExpressionType, string> _operators = new()
{
{ ExpressionType.Equal, "=" },
{ ExpressionType.NotEqual, "!=" },
{ ExpressionType.GreaterThan, ">" },
{ ExpressionType.GreaterThanOrEqual, ">=" },
{ ExpressionType.LessThan, "<" },
{ ExpressionType.LessThanOrEqual, "<=" }
};


private readonly IQueryableCollection _collection;
private IWhereFragmentHolder _holder;

public NewWhereClauseParser(IQueryableCollection collection, IWhereFragmentHolder holder)
{
_collection = collection;
_holder = holder;
}

protected override Expression VisitNew(NewExpression expression)
{
return base.VisitNew(expression);
}

protected override Expression VisitSubQuery(SubQueryExpression expression)
{
return base.VisitSubQuery(expression);
}

protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression)
{
return base.VisitQuerySourceReference(expression);
}

protected override Expression VisitBinary(BinaryExpression node)
{
if (_operators.TryGetValue(node.NodeType, out var op))
{
var binary = new BinaryExpressionVisitor(this);
_holder.Register(binary.BuildWhereFragment(node, op));

return null;
}

switch (node.NodeType)
{
case ExpressionType.AndAlso:
throw new NotImplementedException("Not yet");
//buildCompoundWhereFragment(node, "and");
break;

case ExpressionType.OrElse:
throw new NotImplementedException("Not yet");
//buildCompoundWhereFragment(node, "or");
break;

default:
throw new BadLinqExpressionException(
$"Unsupported expression type '{node.NodeType}' in binary expression");
}


return null;
}

protected override Expression VisitUnary(UnaryExpression node)
{
if (node.NodeType == ExpressionType.Not)
{
var original = _holder;

if (original is IReversibleWhereFragment r)
{
_holder.Register(r.Reverse());
return Visit(node.Operand);
}

_holder = new NotWhereFragment(original);
var returnValue = Visit(node.Operand);

_holder = original;

return returnValue;
}

return null;
}

internal class BinaryExpressionVisitor: RelinqExpressionVisitor
{
private readonly NewWhereClauseParser _parent;
private BinarySide _left;
private BinarySide _right;

public BinaryExpressionVisitor(NewWhereClauseParser parent)
{
_parent = parent;
}

public ISqlFragment BuildWhereFragment(BinaryExpression node, string op)
{
_left = analyze(node.Left);
_right = analyze(node.Right);

return _left.CompareTo(_right, op);
}

private BinarySide analyze(Expression expression)
{
switch (expression)
{
case ConstantExpression c:
return new BinarySide(expression) { Constant = c };
case PartialEvaluationExceptionExpression p:
{
var inner = p.Exception;

throw new BadLinqExpressionException(
$"Error in value expression inside of the query for '{p.EvaluatedExpression}'. See the inner exception:",
inner);
}
case SubQueryExpression subQuery:
{
throw new NotImplementedException("Not yet");
// var parser = new WhereClauseParser.SubQueryFilterParser(_parent, subQuery);
//
// return new BinarySide(expression) { Comparable = parser.BuildCountComparisonStatement() };
}
case QuerySourceReferenceExpression source:
throw new NotImplementedException("Not yet");
//return new BinarySide(expression) { Member = new SimpleDataField(source.Type) };
case BinaryExpression { NodeType: ExpressionType.Modulo } binary:
throw new NotImplementedException("Not yet");
// return new BinarySide(expression)
// {
// Comparable = new ModuloFragment(binary, _parent._statement.Fields)
// };

case BinaryExpression { NodeType: ExpressionType.NotEqual } ne:
throw new NotImplementedException("Not yet");
// if (ne.Right is ConstantExpression v && v.Value == null)
// {
// var field = _parent._statement.Fields.FieldFor(ne.Left);
// return new BinarySide(expression) { Comparable = new HasValueField(field) };
// }

throw new BadLinqExpressionException("Invalid Linq Where() clause with expression: " + ne);
case BinaryExpression binary:
throw new BadLinqExpressionException(
$"Unsupported nested operator '{binary.NodeType}' as an operand in a binary expression");

case UnaryExpression u when u.NodeType == ExpressionType.Not:
throw new NotImplementedException("Not yet");
// return new BinarySide(expression)
// {
// Comparable = new NotField(_parent._statement.Fields.FieldFor(u.Operand))
// };

default:
return new BinarySide(expression) { Member = _parent._collection.MemberFor(expression) };
}
}



}

internal class BinarySide
{
public BinarySide(Expression memberExpression)
{
MemberExpression = memberExpression as MemberExpression;
}

public ConstantExpression Constant { get; set; }
public IQueryableMember Member { get; set; }

public MemberExpression MemberExpression { get; }

public ISqlFragment CompareTo(BinarySide right, string op)
{
if (Constant != null)
{
return right.CompareTo(this, ComparisonFilter.OppositeOperators[op]);
}

if (Member is IComparableMember m && right.Constant != null)
{
return m.CreateComparison(op, right.Constant);
}

if (Member == null)
{
throw new BadLinqExpressionException("Unsupported binary value expression in a Where() clause");
}

if (right.Constant != null && Member is IComparableMember comparableMember)
{
return comparableMember.CreateComparison(op, right.Constant);
}

if (right.Member != null)
{
// TODO -- this will need to evaluate extra methods in the comparison. Looking for StringProp.ToLower() == "foo"
return new ComparisonFilter(Member, right.Member, op);
}


throw new BadLinqExpressionException("Unsupported binary value expression in a Where() clause");
}
}
}
6 changes: 5 additions & 1 deletion src/Marten/Linq/New/Operators/WhereOperator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ public WhereOperator() : base("Where")
public override MethodCallExpression Apply(ILinqQuery query, MethodCallExpression expression)
{
var usage = query.CollectionUsageFor(expression);
usage.Wheres.Add(expression.Arguments.Last());
var where = expression.Arguments.Last();
if (where is UnaryExpression e) where = e.Operand;
if (where is LambdaExpression l) where = l.Body;

usage.Wheres.Add(where);

var elementType = usage.ElementType;
return Expression.Call(expression.Method, expression.Arguments[0], Expression.Constant(_always[elementType]));
Expand Down
Loading

0 comments on commit 2866ce8

Please sign in to comment.