From 836976d6e3aad5aa7a7e3a80ff772023673ca8b9 Mon Sep 17 00:00:00 2001 From: Christophe PLAT Date: Thu, 22 Apr 2021 18:18:01 +0200 Subject: [PATCH] Unary and binary operators: find the proper overload by also checking the base classes. --- src/DynamicExpresso.Core/Parsing/Parser.cs | 172 +++++++++++++++--- .../Resources/ErrorMessages.Designer.cs | 24 ++- .../Resources/ErrorMessages.resx | 6 + .../DynamicExpresso.UnitTest/OperatorsTest.cs | 49 +++++ 4 files changed, 222 insertions(+), 29 deletions(-) diff --git a/src/DynamicExpresso.Core/Parsing/Parser.cs b/src/DynamicExpresso.Core/Parsing/Parser.cs index ad8a6f5a..f57ff130 100644 --- a/src/DynamicExpresso.Core/Parsing/Parser.cs +++ b/src/DynamicExpresso.Core/Parsing/Parser.cs @@ -151,7 +151,7 @@ private Expression ParseConditionalOr() NextToken(); var right = ParseConditionalAnd(); CheckAndPromoteOperands(typeof(ParseSignatures.ILogicalSignatures), ref left, ref right); - left = Expression.OrElse(left, right); + left = GenerateBinary(ExpressionType.OrElse, left, right); } return left; } @@ -165,7 +165,7 @@ private Expression ParseConditionalAnd() NextToken(); var right = ParseLogicalOr(); CheckAndPromoteOperands(typeof(ParseSignatures.ILogicalSignatures), ref left, ref right); - left = Expression.AndAlso(left, right); + left = GenerateBinary(ExpressionType.AndAlso, left, right); } return left; } @@ -179,7 +179,7 @@ private Expression ParseLogicalOr() NextToken(); var right = ParseLogicalXor(); CheckAndPromoteOperands(typeof(ParseSignatures.ILogicalSignatures), ref left, ref right); - left = Expression.Or(left, right); + left = GenerateBinary(ExpressionType.Or, left, right); } return left; } @@ -193,7 +193,7 @@ private Expression ParseLogicalXor() NextToken(); var right = ParseLogicalAnd(); CheckAndPromoteOperands(typeof(ParseSignatures.ILogicalSignatures), ref left, ref right); - left = Expression.ExclusiveOr(left, right); + left = GenerateBinary(ExpressionType.ExclusiveOr, left, right); } return left; } @@ -207,7 +207,7 @@ private Expression ParseLogicalAnd() NextToken(); var right = ParseComparison(); CheckAndPromoteOperands(typeof(ParseSignatures.ILogicalSignatures), ref left, ref right); - left = Expression.And(left, right); + left = GenerateBinary(ExpressionType.And, left, right); } return left; } @@ -354,12 +354,12 @@ private Expression ParseAdditive() else { CheckAndPromoteOperands(typeof(ParseSignatures.IAddSignatures), ref left, ref right); - left = GenerateAdd(left, right); + left = GenerateBinary(ExpressionType.Add, left, right); } break; case TokenId.Minus: CheckAndPromoteOperands(typeof(ParseSignatures.ISubtractSignatures), ref left, ref right); - left = GenerateSubtract(left, right); + left = GenerateBinary(ExpressionType.Subtract, left, right); break; } } @@ -382,13 +382,13 @@ private Expression ParseMultiplicative() switch (op.id) { case TokenId.Asterisk: - left = Expression.Multiply(left, right); + left = GenerateBinary(ExpressionType.Multiply, left, right); break; case TokenId.Slash: - left = Expression.Divide(left, right); + left = GenerateBinary(ExpressionType.Divide, left, right); break; case TokenId.Percent: - left = Expression.Modulo(left, right); + left = GenerateBinary(ExpressionType.Modulo, left, right); break; } } @@ -423,7 +423,7 @@ private Expression ParseUnary() if (op.id == TokenId.Minus) { CheckAndPromoteOperand(typeof(ParseSignatures.INegationSignatures), ref expr); - expr = Expression.Negate(expr); + expr = GenerateUnary(ExpressionType.Negate, expr); } else if (op.id == TokenId.Plus) { @@ -432,13 +432,58 @@ private Expression ParseUnary() else if (op.id == TokenId.Exclamation) { CheckAndPromoteOperand(typeof(ParseSignatures.INotSignatures), ref expr); - expr = Expression.Not(expr); + expr = GenerateUnary(ExpressionType.Not, expr); } return expr; } return ParsePrimary(); } + private Expression GenerateUnary(ExpressionType unaryType, Expression expr) + { + // find the overloaded unary operator + string opName; + switch (unaryType) + { + case ExpressionType.Negate: opName = "op_UnaryNegation"; break; + case ExpressionType.Not: opName = "op_LogicalNot"; break; + default: opName = null; break; + } + + var applicableMethod = FindUnaryOperator(opName, expr); + + MethodInfo operatorMethod = null; + if (applicableMethod != null) + { + operatorMethod = applicableMethod.MethodBase as MethodInfo; + expr = applicableMethod.PromotedParameters[0]; + } + + // if no operator was found, the default Linq resolution will occur + return Expression.MakeUnary(unaryType, expr, null, operatorMethod); + } + + private MethodData FindUnaryOperator(string operatorName, Expression expr) + { + if (operatorName == null) + return null; + + var errorPos = _token.pos; + var type = expr.Type; + var args = new[] { expr }; + + // try to find the user defined operator on both operands + var applicableMethods = FindMethods(type, operatorName, true, args); + if (applicableMethods.Length > 1) + throw CreateParseException(errorPos, ErrorMessages.AmbiguousUnaryOperatorInvocation, operatorName, GetTypeName(type)); + + MethodData userDefinedOperator = null; + if (applicableMethods.Length == 1) + userDefinedOperator = applicableMethods[0]; + + return userDefinedOperator; + } + private Expression ParsePrimary() { var tokenPos = _token.pos; @@ -1757,17 +1802,17 @@ private static int CompareConversions(Type s, Type t1, Type t2) return 0; } - private static Expression GenerateEqual(Expression left, Expression right) + private Expression GenerateEqual(Expression left, Expression right) { - return Expression.Equal(left, right); + return GenerateBinary(ExpressionType.Equal, left, right); } - private static Expression GenerateNotEqual(Expression left, Expression right) + private Expression GenerateNotEqual(Expression left, Expression right) { - return Expression.NotEqual(left, right); + return GenerateBinary(ExpressionType.NotEqual, left, right); } - private static Expression GenerateGreaterThan(Expression left, Expression right) + private Expression GenerateGreaterThan(Expression left, Expression right) { if (left.Type == typeof(string)) { @@ -1776,10 +1821,10 @@ private static Expression GenerateGreaterThan(Expression left, Expression right) Expression.Constant(0) ); } - return Expression.GreaterThan(left, right); + return GenerateBinary(ExpressionType.GreaterThan, left, right); } - private static Expression GenerateGreaterThanEqual(Expression left, Expression right) + private Expression GenerateGreaterThanEqual(Expression left, Expression right) { if (left.Type == typeof(string)) { @@ -1788,10 +1833,10 @@ private static Expression GenerateGreaterThanEqual(Expression left, Expression r Expression.Constant(0) ); } - return Expression.GreaterThanOrEqual(left, right); + return GenerateBinary(ExpressionType.GreaterThanOrEqual, left, right); } - private static Expression GenerateLessThan(Expression left, Expression right) + private Expression GenerateLessThan(Expression left, Expression right) { if (left.Type == typeof(string)) { @@ -1801,10 +1846,10 @@ private static Expression GenerateLessThan(Expression left, Expression right) ); } - return Expression.LessThan(left, right); + return GenerateBinary(ExpressionType.LessThan, left, right); } - private static Expression GenerateLessThanEqual(Expression left, Expression right) + private Expression GenerateLessThanEqual(Expression left, Expression right) { if (left.Type == typeof(string)) { @@ -1813,17 +1858,88 @@ private static Expression GenerateLessThanEqual(Expression left, Expression righ Expression.Constant(0) ); } - return Expression.LessThanOrEqual(left, right); + return GenerateBinary(ExpressionType.LessThanOrEqual, left, right); } - private static Expression GenerateAdd(Expression left, Expression right) + private Expression GenerateBinary(ExpressionType binaryType, Expression left, Expression right) { - return Expression.Add(left, right); + // find the overloaded binary operator + string opName; + + var liftToNull = true; + switch (binaryType) + { + case ExpressionType.OrElse: opName = "op_BitwiseOr"; break; + case ExpressionType.Or: opName = "op_BitwiseOr"; break; + case ExpressionType.ExclusiveOr: opName = "op_ExclusiveOr"; break; + case ExpressionType.AndAlso: opName = "op_BitwiseAnd"; break; + case ExpressionType.And: opName = "op_BitwiseAnd"; break; + case ExpressionType.Add: opName = "op_Addition"; break; + case ExpressionType.Subtract: opName = "op_Subtraction"; break; + case ExpressionType.Multiply: opName = "op_Multiply"; break; + case ExpressionType.Divide: opName = "op_Division"; break; + case ExpressionType.Modulo: opName = "op_Modulus"; break; + case ExpressionType.Equal: opName = "op_Equality"; liftToNull = false; break; + case ExpressionType.NotEqual: opName = "op_Inequality"; liftToNull = false; break; + case ExpressionType.GreaterThan: opName = "op_GreaterThan"; liftToNull = false; break; + case ExpressionType.GreaterThanOrEqual: opName = "op_GreaterThanOrEqual"; liftToNull = false; break; + case ExpressionType.LessThan: opName = "op_LessThan"; liftToNull = false; break; + case ExpressionType.LessThanOrEqual: opName = "op_LessThanOrEqual"; liftToNull = false; break; + default: opName = null; break; + } + + var applicableMethod = FindBinaryOperator(opName, left, right); + + MethodInfo operatorMethod = null; + if (applicableMethod != null) + { + operatorMethod = applicableMethod.MethodBase as MethodInfo; + left = applicableMethod.PromotedParameters[0]; + right = applicableMethod.PromotedParameters[1]; + } + + // if no operator was found, the default Linq resolution will occur + return Expression.MakeBinary(binaryType, left, right, liftToNull, operatorMethod); } - private static Expression GenerateSubtract(Expression left, Expression right) + private MethodData FindBinaryOperator(string operatorName, Expression left, Expression right) { - return Expression.Subtract(left, right); + if (operatorName == null) + return null; + + var errorPos = _token.pos; + var leftType = left.Type; + var rightType = right.Type; + var error = CreateParseException(errorPos, ErrorMessages.AmbiguousBinaryOperatorInvocation, operatorName, GetTypeName(leftType), GetTypeName(rightType)); + + var args = new[] { left, right }; + + MethodData userDefinedOperator = null; + + // try to find the user defined operator on both operands + var opOnLeftType = FindMethods(leftType, operatorName, true, args); + if (opOnLeftType.Length > 1) + throw error; + + if (opOnLeftType.Length == 1) + userDefinedOperator = opOnLeftType[0]; + + if (leftType != rightType) + { + var opOnRightType = FindMethods(rightType, operatorName, true, args); + if (opOnRightType.Length > 1) + throw error; + + MethodData rightOperator = null; + if (opOnRightType.Length == 1) + rightOperator = opOnRightType[0]; + + // we found a matching user defined operator on either type, but it might be the same method + if (userDefinedOperator != null && rightOperator != null && !ReferenceEquals(userDefinedOperator.MethodBase, rightOperator.MethodBase)) + throw error; + } + + return userDefinedOperator; } private static Expression GenerateStringConcat(Expression left, Expression right) diff --git a/src/DynamicExpresso.Core/Resources/ErrorMessages.Designer.cs b/src/DynamicExpresso.Core/Resources/ErrorMessages.Designer.cs index 8e61887d..9e48c3bc 100644 --- a/src/DynamicExpresso.Core/Resources/ErrorMessages.Designer.cs +++ b/src/DynamicExpresso.Core/Resources/ErrorMessages.Designer.cs @@ -77,7 +77,29 @@ internal static string AmbiguousMethodInvocation { return ResourceManager.GetString("AmbiguousMethodInvocation", resourceCulture); } } - + + /// + /// Looks up a localized string similar to Ambiguous invocation of user defined operator '{0}' in type '{1}'. + /// + internal static string AmbiguousUnaryOperatorInvocation + { + get + { + return ResourceManager.GetString("AmbiguousUnaryOperatorInvocation", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Ambiguous invocation of user defined operator '{0}' in types '{1}' and '{2}'. + /// + internal static string AmbiguousBinaryOperatorInvocation + { + get + { + return ResourceManager.GetString("AmbiguousBinaryOperatorInvocation", resourceCulture); + } + } + /// /// Looks up a localized string similar to Argument list incompatible with delegate expression. /// diff --git a/src/DynamicExpresso.Core/Resources/ErrorMessages.resx b/src/DynamicExpresso.Core/Resources/ErrorMessages.resx index 771fc040..875b4ae8 100644 --- a/src/DynamicExpresso.Core/Resources/ErrorMessages.resx +++ b/src/DynamicExpresso.Core/Resources/ErrorMessages.resx @@ -234,4 +234,10 @@ Invalid Operation + + Ambiguous invocation of user defined operator '{0}' in types '{1}' and '{2}' + + + Ambiguous invocation of user defined operator '{0}' in type '{1}' + \ No newline at end of file diff --git a/test/DynamicExpresso.UnitTest/OperatorsTest.cs b/test/DynamicExpresso.UnitTest/OperatorsTest.cs index 04447163..71610ba3 100644 --- a/test/DynamicExpresso.UnitTest/OperatorsTest.cs +++ b/test/DynamicExpresso.UnitTest/OperatorsTest.cs @@ -581,6 +581,55 @@ public override int GetHashCode() } } + private class DerivedClassWithOverloadedBinaryOperators : ClassWithOverloadedBinaryOperators + { + public DerivedClassWithOverloadedBinaryOperators(int value) : base(value) + { + } + } + + [Test] + public void Can_use_overloaded_operators_on_derived_class() + { + var target = new Interpreter(); + + var x = new DerivedClassWithOverloadedBinaryOperators(3); + target.SetVariable("x", x); + + string y = "5"; + Assert.IsFalse(x == y); + Assert.IsFalse(target.Eval("x == y", new Parameter("y", y))); + + y = "3"; + Assert.IsTrue(x == y); + Assert.IsTrue(target.Eval("x == y", new Parameter("y", y))); + + Assert.IsFalse(target.Eval("x == \"4\"")); + Assert.IsTrue(target.Eval("x == \"3\"")); + + Assert.IsTrue(!x == "-3"); + Assert.IsTrue(target.Eval("!x == \"-3\"")); + + var z = new DerivedClassWithOverloadedBinaryOperators(10); + Assert.IsTrue((x + z) == "13"); + Assert.IsTrue(target.Eval("(x + z) == \"13\"", new Parameter("z", z))); + } + + [Test] + public void Can_mix_overloaded_operators() + { + var target = new Interpreter(); + + var x = new ClassWithOverloadedBinaryOperators(3); + target.SetVariable("x", x); + + // ensure we don't trigger an ambiguous operator exception + var z = new DerivedClassWithOverloadedBinaryOperators(10); + Assert.IsTrue((x + z) == "13"); + Assert.IsTrue(target.Eval("(x + z) == \"13\"", new Parameter("z", z))); + } + + [Test] public void Throw_an_exception_if_a_custom_type_doesnt_define_equal_operator() {