From 2866ce8be89fff138365a9e9697739b49dc78c0b Mon Sep 17 00:00:00 2001 From: "Jeremy D. Miller" Date: Wed, 12 Jul 2023 13:13:52 -0500 Subject: [PATCH] Booyah! First successful usage of the new Linq parsing w/ WHERE clauses! --- src/LinqTests/playing.cs | 2 +- src/Marten/Linq/New/CollectionUsage.cs | 2 +- src/Marten/Linq/New/DocumentCollection.cs | 14 ++ src/Marten/Linq/New/IQueryableCollection.cs | 7 +- .../Linq/New/Members/IQueryableMember.cs | 10 +- src/Marten/Linq/New/Members/NumericMember.cs | 19 +- .../Linq/New/Members/QueryableMember.cs | 13 +- src/Marten/Linq/New/NewWhereClauseParser.cs | 227 ++++++++++++++++++ .../Linq/New/Operators/WhereOperator.cs | 6 +- src/Marten/Linq/New/WhereClauseParser.cs | 45 ---- .../Parsing/WhereClauseParser.BinarySide.cs | 2 +- 11 files changed, 293 insertions(+), 54 deletions(-) create mode 100644 src/Marten/Linq/New/NewWhereClauseParser.cs delete mode 100644 src/Marten/Linq/New/WhereClauseParser.cs diff --git a/src/LinqTests/playing.cs b/src/LinqTests/playing.cs index 33b7c788c7..cf3626567b 100644 --- a/src/LinqTests/playing.cs +++ b/src/LinqTests/playing.cs @@ -58,7 +58,7 @@ public async Task try_linq() var targets = await theSession.Query() .OrderBy(x => x.Double) - .Where(x => x.Number == 5) + .Where(x => x.Number == 6) .Take(5) .ToListAsync(); diff --git a/src/Marten/Linq/New/CollectionUsage.cs b/src/Marten/Linq/New/CollectionUsage.cs index 02f4b9e395..535313cff1 100644 --- a/src/Marten/Linq/New/CollectionUsage.cs +++ b/src/Marten/Linq/New/CollectionUsage.cs @@ -47,7 +47,7 @@ public NewStatement BuildStatement(DocumentCollection collection) } // TODO -- watch for "ANDS" - var parser = new WhereClauseParser(collection, statement); + var parser = new NewWhereClauseParser(collection, statement); foreach (var expression in Wheres) { parser.Visit(expression); diff --git a/src/Marten/Linq/New/DocumentCollection.cs b/src/Marten/Linq/New/DocumentCollection.cs index a6ff9db61c..b14883395c 100644 --- a/src/Marten/Linq/New/DocumentCollection.cs +++ b/src/Marten/Linq/New/DocumentCollection.cs @@ -1,4 +1,6 @@ using System; +using System.Linq; +using System.Linq.Expressions; using System.Reflection; using JasperFx.Core; using JasperFx.Core.Reflection; @@ -37,6 +39,18 @@ public IQueryableMember FindMember(MemberInfo member) return m; } + public IQueryableMember MemberFor(Expression expression) + { + // TODO -- NEEDS TO GET MORE COMPLICATED LATER + + var members = FindMembers.Determine(expression); + + if (members.Length != 1) + throw new NotImplementedException("Hey, deep accessors aren't yet supported in the new Linq provider"); + + return FindMember(members.Single()); + } + private IQueryableMember createQuerableMember(MemberInfo member) { // TODO -- this will need to get MUCH fancier later diff --git a/src/Marten/Linq/New/IQueryableCollection.cs b/src/Marten/Linq/New/IQueryableCollection.cs index ed7c0d3068..61e0aebb31 100644 --- a/src/Marten/Linq/New/IQueryableCollection.cs +++ b/src/Marten/Linq/New/IQueryableCollection.cs @@ -1,4 +1,5 @@ using System; +using System.Linq.Expressions; using System.Reflection; using Marten.Linq.New.Members; @@ -9,4 +10,8 @@ public interface IQueryableCollection Type ElementType { get; } IQueryableMember FindMember(MemberInfo member); -} \ No newline at end of file + // Assume there's a separate member for the usage of a member with methods + // i.e., StringMember.ToLower() + // Dictionary fields will create a separate "dictionary value field" + IQueryableMember MemberFor(Expression expression); +} diff --git a/src/Marten/Linq/New/Members/IQueryableMember.cs b/src/Marten/Linq/New/Members/IQueryableMember.cs index 90042fe385..d29bca6952 100644 --- a/src/Marten/Linq/New/Members/IQueryableMember.cs +++ b/src/Marten/Linq/New/Members/IQueryableMember.cs @@ -1,10 +1,13 @@ using System; +using System.Linq.Expressions; using System.Reflection; using Marten.Linq.New.Operators; +using Weasel.Postgresql.SqlGeneration; namespace Marten.Linq.New.Members; -public interface IQueryableMember + +public interface IQueryableMember : ISqlFragment { MemberInfo Member { get; } Type MemberType { get; } @@ -26,3 +29,8 @@ public interface IQueryableMember /// string BuildOrderingExpression(NewOrdering ordering); } + +public interface IComparableMember +{ + ISqlFragment CreateComparison(string op, ConstantExpression constant); +} diff --git a/src/Marten/Linq/New/Members/NumericMember.cs b/src/Marten/Linq/New/Members/NumericMember.cs index 0ef2fae350..726665c4ba 100644 --- a/src/Marten/Linq/New/Members/NumericMember.cs +++ b/src/Marten/Linq/New/Members/NumericMember.cs @@ -1,10 +1,13 @@ +using System; +using System.Linq.Expressions; using System.Reflection; using Weasel.Core; using Weasel.Postgresql; +using Weasel.Postgresql.SqlGeneration; namespace Marten.Linq.New.Members; -internal class NumericMember: QueryableMember +internal class NumericMember: QueryableMember, IComparableMember { public NumericMember(string parentLocator, Casing casing, MemberInfo member) : base(parentLocator, casing, member) { @@ -15,4 +18,16 @@ public NumericMember(string parentLocator, Casing casing, MemberInfo member) : b TypedLocator = $"CAST({RawLocator} as {pgType})"; } -} \ No newline at end of file + + public ISqlFragment CreateComparison(string op, ConstantExpression constant) + { + if (constant.Value == null) + { + throw new NotImplementedException("Not yet"); + //return op == "=" ? new IsNullFilter(this) : new IsNotNullFilter(this); + } + + var def = new CommandParameter(constant); + return new ComparisonFilter(this, def, op); + } +} diff --git a/src/Marten/Linq/New/Members/QueryableMember.cs b/src/Marten/Linq/New/Members/QueryableMember.cs index 40258fe521..03aa705a3b 100644 --- a/src/Marten/Linq/New/Members/QueryableMember.cs +++ b/src/Marten/Linq/New/Members/QueryableMember.cs @@ -4,6 +4,8 @@ using Marten.Linq.New.Operators; using Marten.Util; using Remotion.Linq.Clauses; +using Weasel.Postgresql; +using Weasel.Postgresql.SqlGeneration; namespace Marten.Linq.New.Members; @@ -40,4 +42,13 @@ public virtual string BuildOrderingExpression(NewOrdering ordering) } -} \ No newline at end of file + void ISqlFragment.Apply(CommandBuilder builder) + { + builder.Append(TypedLocator); + } + + bool ISqlFragment.Contains(string sqlText) + { + return false; + } +} diff --git a/src/Marten/Linq/New/NewWhereClauseParser.cs b/src/Marten/Linq/New/NewWhereClauseParser.cs new file mode 100644 index 0000000000..37f2ed5e8a --- /dev/null +++ b/src/Marten/Linq/New/NewWhereClauseParser.cs @@ -0,0 +1,227 @@ +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using Marten.Exceptions; +using Marten.Linq.Fields; +using Marten.Linq.New.Members; +using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Parsing; +using Weasel.Postgresql.SqlGeneration; + +namespace Marten.Linq.New; + +public partial class NewWhereClauseParser : RelinqExpressionVisitor +{ + private static readonly Dictionary _operators = new() + { + { ExpressionType.Equal, "=" }, + { ExpressionType.NotEqual, "!=" }, + { ExpressionType.GreaterThan, ">" }, + { ExpressionType.GreaterThanOrEqual, ">=" }, + { ExpressionType.LessThan, "<" }, + { ExpressionType.LessThanOrEqual, "<=" } + }; + + + private readonly IQueryableCollection _collection; + private IWhereFragmentHolder _holder; + + public NewWhereClauseParser(IQueryableCollection collection, IWhereFragmentHolder holder) + { + _collection = collection; + _holder = holder; + } + + protected override Expression VisitNew(NewExpression expression) + { + return base.VisitNew(expression); + } + + protected override Expression VisitSubQuery(SubQueryExpression expression) + { + return base.VisitSubQuery(expression); + } + + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) + { + return base.VisitQuerySourceReference(expression); + } + + protected override Expression VisitBinary(BinaryExpression node) + { + if (_operators.TryGetValue(node.NodeType, out var op)) + { + var binary = new BinaryExpressionVisitor(this); + _holder.Register(binary.BuildWhereFragment(node, op)); + + return null; + } + + switch (node.NodeType) + { + case ExpressionType.AndAlso: + throw new NotImplementedException("Not yet"); + //buildCompoundWhereFragment(node, "and"); + break; + + case ExpressionType.OrElse: + throw new NotImplementedException("Not yet"); + //buildCompoundWhereFragment(node, "or"); + break; + + default: + throw new BadLinqExpressionException( + $"Unsupported expression type '{node.NodeType}' in binary expression"); + } + + + return null; + } + + protected override Expression VisitUnary(UnaryExpression node) + { + if (node.NodeType == ExpressionType.Not) + { + var original = _holder; + + if (original is IReversibleWhereFragment r) + { + _holder.Register(r.Reverse()); + return Visit(node.Operand); + } + + _holder = new NotWhereFragment(original); + var returnValue = Visit(node.Operand); + + _holder = original; + + return returnValue; + } + + return null; + } + + internal class BinaryExpressionVisitor: RelinqExpressionVisitor + { + private readonly NewWhereClauseParser _parent; + private BinarySide _left; + private BinarySide _right; + + public BinaryExpressionVisitor(NewWhereClauseParser parent) + { + _parent = parent; + } + + public ISqlFragment BuildWhereFragment(BinaryExpression node, string op) + { + _left = analyze(node.Left); + _right = analyze(node.Right); + + return _left.CompareTo(_right, op); + } + + private BinarySide analyze(Expression expression) + { + switch (expression) + { + case ConstantExpression c: + return new BinarySide(expression) { Constant = c }; + case PartialEvaluationExceptionExpression p: + { + var inner = p.Exception; + + throw new BadLinqExpressionException( + $"Error in value expression inside of the query for '{p.EvaluatedExpression}'. See the inner exception:", + inner); + } + case SubQueryExpression subQuery: + { + throw new NotImplementedException("Not yet"); + // var parser = new WhereClauseParser.SubQueryFilterParser(_parent, subQuery); + // + // return new BinarySide(expression) { Comparable = parser.BuildCountComparisonStatement() }; + } + case QuerySourceReferenceExpression source: + throw new NotImplementedException("Not yet"); + //return new BinarySide(expression) { Member = new SimpleDataField(source.Type) }; + case BinaryExpression { NodeType: ExpressionType.Modulo } binary: + throw new NotImplementedException("Not yet"); + // return new BinarySide(expression) + // { + // Comparable = new ModuloFragment(binary, _parent._statement.Fields) + // }; + + case BinaryExpression { NodeType: ExpressionType.NotEqual } ne: + throw new NotImplementedException("Not yet"); + // if (ne.Right is ConstantExpression v && v.Value == null) + // { + // var field = _parent._statement.Fields.FieldFor(ne.Left); + // return new BinarySide(expression) { Comparable = new HasValueField(field) }; + // } + + throw new BadLinqExpressionException("Invalid Linq Where() clause with expression: " + ne); + case BinaryExpression binary: + throw new BadLinqExpressionException( + $"Unsupported nested operator '{binary.NodeType}' as an operand in a binary expression"); + + case UnaryExpression u when u.NodeType == ExpressionType.Not: + throw new NotImplementedException("Not yet"); + // return new BinarySide(expression) + // { + // Comparable = new NotField(_parent._statement.Fields.FieldFor(u.Operand)) + // }; + + default: + return new BinarySide(expression) { Member = _parent._collection.MemberFor(expression) }; + } + } + + + + } + + internal class BinarySide + { + public BinarySide(Expression memberExpression) + { + MemberExpression = memberExpression as MemberExpression; + } + + public ConstantExpression Constant { get; set; } + public IQueryableMember Member { get; set; } + + public MemberExpression MemberExpression { get; } + + public ISqlFragment CompareTo(BinarySide right, string op) + { + if (Constant != null) + { + return right.CompareTo(this, ComparisonFilter.OppositeOperators[op]); + } + + if (Member is IComparableMember m && right.Constant != null) + { + return m.CreateComparison(op, right.Constant); + } + + if (Member == null) + { + throw new BadLinqExpressionException("Unsupported binary value expression in a Where() clause"); + } + + if (right.Constant != null && Member is IComparableMember comparableMember) + { + return comparableMember.CreateComparison(op, right.Constant); + } + + if (right.Member != null) + { + // TODO -- this will need to evaluate extra methods in the comparison. Looking for StringProp.ToLower() == "foo" + return new ComparisonFilter(Member, right.Member, op); + } + + + throw new BadLinqExpressionException("Unsupported binary value expression in a Where() clause"); + } + } +} diff --git a/src/Marten/Linq/New/Operators/WhereOperator.cs b/src/Marten/Linq/New/Operators/WhereOperator.cs index 4b6029d104..783a7f37af 100644 --- a/src/Marten/Linq/New/Operators/WhereOperator.cs +++ b/src/Marten/Linq/New/Operators/WhereOperator.cs @@ -34,7 +34,11 @@ public WhereOperator() : base("Where") public override MethodCallExpression Apply(ILinqQuery query, MethodCallExpression expression) { var usage = query.CollectionUsageFor(expression); - usage.Wheres.Add(expression.Arguments.Last()); + var where = expression.Arguments.Last(); + if (where is UnaryExpression e) where = e.Operand; + if (where is LambdaExpression l) where = l.Body; + + usage.Wheres.Add(where); var elementType = usage.ElementType; return Expression.Call(expression.Method, expression.Arguments[0], Expression.Constant(_always[elementType])); diff --git a/src/Marten/Linq/New/WhereClauseParser.cs b/src/Marten/Linq/New/WhereClauseParser.cs deleted file mode 100644 index f213965848..0000000000 --- a/src/Marten/Linq/New/WhereClauseParser.cs +++ /dev/null @@ -1,45 +0,0 @@ -using System.Linq.Expressions; -using Remotion.Linq.Clauses.Expressions; -using Remotion.Linq.Parsing; -using Weasel.Postgresql.SqlGeneration; - -namespace Marten.Linq.New; - -public class WhereClauseParser : RelinqExpressionVisitor -{ - private readonly IQueryableCollection _collection; - private readonly IWhereFragmentHolder _holder; - - public WhereClauseParser(IQueryableCollection collection, IWhereFragmentHolder holder) - { - _collection = collection; - _holder = holder; - } - - protected override Expression VisitNew(NewExpression expression) - { - return base.VisitNew(expression); - } - - protected override Expression VisitSubQuery(SubQueryExpression expression) - { - return base.VisitSubQuery(expression); - } - - protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) - { - return base.VisitQuerySourceReference(expression); - } - - protected override Expression VisitBinary(BinaryExpression node) - { - return base.VisitBinary(node); - } - - protected override Expression VisitUnary(UnaryExpression node) - { - return base.VisitUnary(node); - } - - -} diff --git a/src/Marten/Linq/Parsing/WhereClauseParser.BinarySide.cs b/src/Marten/Linq/Parsing/WhereClauseParser.BinarySide.cs index 01cfe24dab..43a7d98730 100644 --- a/src/Marten/Linq/Parsing/WhereClauseParser.BinarySide.cs +++ b/src/Marten/Linq/Parsing/WhereClauseParser.BinarySide.cs @@ -19,7 +19,7 @@ public BinarySide(Expression memberExpression) public IField Field { get; set; } public IComparableFragment Comparable { get; set; } - public Expression MemberExpression { get; }s + public Expression MemberExpression { get; } public ISqlFragment CompareTo(BinarySide right, string op) {