C#-like handling of null in comparison operators #71

Open
Yankes opened this issue Sep 21, 2022 · 3 comments

Yankes commented Sep 21, 2022

In C#, an expression like (int?)0 == null evaluates to false, but in NCalc2 every null is converted to 0 because of LambdaExpressionVistor.UnwrapNullable.

As a result, NullInt() == 0 returns true.
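
For reference, here is a small sketch of the C# semantics being asked for, using the language's built-in lifted operators on Nullable<int>. The variable names are made up for illustration; nothing here is NCalc2 code.

```csharp
using System;

int? missing = null;
int? alsoMissing = null;
int? zero = 0;

// Equality: null is equal only to null, never to a value.
bool zeroEqualsNull = zero == missing;            // false
bool nullEqualsNull = missing == alsoMissing;     // true

// Ordering comparisons: false whenever either operand is null.
bool nullLessOrEqualZero = missing <= zero;       // false
bool nullGreaterOrEqualNull = missing >= alsoMissing; // false

// Arithmetic: null propagates to the result.
int? sum = missing + 1;                           // null

Console.WriteLine($"{zeroEqualsNull} {nullEqualsNull} {nullLessOrEqualZero} {nullGreaterOrEqualNull} {sum.HasValue}");
// prints "False True False False False"
```

Note that null >= null is false even though null == null is true, so a null-aware implementation cannot simply treat null as 0 and compare.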

Because some user expressions may rely on the current behavior, I think the best approach is a new option flag that enables C#-like null handling.
When this flag is enabled, LambdaExpressionVistor.WithCommonNumericType would add code like:

        private bool IsNullable(L.Expression expression)
        {
            var ti = expression.Type.GetTypeInfo();
            return ti.IsGenericType && ti.GetGenericTypeDefinition() == typeof(Nullable<>);
        }

        private L.Expression HaveValueExpression(L.Expression expression)
        {
            if (IsNullable(expression))
            {
                return L.Expression.Property(expression, "HasValue");
            }
            else
            {
                return L.Expression.Constant(true);
            }
        }

        private Func<L.Expression, L.Expression, L.Expression> UnwrapNullableAction(L.Expression left, L.Expression right, Func<L.Expression, L.Expression, L.Expression> action, BinaryExpressionType expressiontype)
        {
            if (/*TODO should we use null like C#? */ && (IsNullable(left) || IsNullable(right)))
            {
                if (/*TODO is comparison expressiontype */)
                {
                    // Run the comparison only when both operands agree on HasValue; otherwise the result is false.
                    return (l, r) => L.Expression.Condition(
                        L.Expression.Equal(HaveValueExpression(left), HaveValueExpression(right)),
                        action(l, r), // will compare `0 == 0` instead of `null == null` but this is fine
                        L.Expression.Constant(false));
                }
                else
                {
                    // Arithmetic: compute only when both operands have a value; otherwise yield null.
                    return (l, r) => L.Expression.Condition(
                        L.Expression.And(HaveValueExpression(left), HaveValueExpression(right)),
                        action(l, r),
                        L.Expression.Constant(null));
                }
            }
            else
            {
                return action;
            }
        }

        private L.Expression WithCommonNumericType(L.Expression left, L.Expression right,
            Func<L.Expression, L.Expression, L.Expression> action, BinaryExpressionType expressiontype = BinaryExpressionType.Unknown)
        {
            action = UnwrapNullableAction(left, right, action, expressiontype);
            left = UnwrapNullable(left);
            right = UnwrapNullable(right);

            /*TODO rest of logic */
        }
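
To make the idea above concrete, here is a minimal, self-contained sketch of a null-aware equality built directly with System.Linq.Expressions, outside of NCalc2. The helper names (HasValue, ValueOrDefault, NullAwareEqual) are hypothetical and chosen only for this illustration; it is not the library's implementation.

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical helper: HasValue for Nullable<T> operands, constant true otherwise.
Expression HasValue(Expression e) =>
    Nullable.GetUnderlyingType(e.Type) != null
        ? Expression.Property(e, "HasValue")
        : Expression.Constant(true);

// Hypothetical helper: unwrap Nullable<T> to T (default(T) when null).
Expression ValueOrDefault(Expression e) =>
    Nullable.GetUnderlyingType(e.Type) != null
        ? Expression.Call(e, e.Type.GetMethod("GetValueOrDefault", Type.EmptyTypes))
        : e;

// Compare the values only when both sides have one;
// otherwise two operands are equal only if both are null.
Expression NullAwareEqual(Expression left, Expression right) =>
    Expression.Condition(
        Expression.AndAlso(HasValue(left), HasValue(right)),
        Expression.Equal(ValueOrDefault(left), ValueOrDefault(right)),
        Expression.Equal(HasValue(left), HasValue(right)));

var a = Expression.Parameter(typeof(int?), "a");
var b = Expression.Parameter(typeof(int), "b");
var equal = Expression.Lambda<Func<int?, int, bool>>(NullAwareEqual(a, b), a, b).Compile();

Console.WriteLine(equal(null, 0)); // False
Console.WriteLine(equal(0, 0));    // True
```

Both branches of Expression.Condition must have the same type, which is why the fallback branch is itself a bool expression rather than an untyped null constant.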

Yankes (Author) commented Sep 21, 2022

For my project I hacked LambdaExpressionVistor to work better with nulls:

        internal class HackExpressionVistor : LogicalExpressionVisitor
        {
            private readonly IDictionary<string, object> _parameters;
            private L.Expression _result;
            private readonly L.Expression _context;
            private readonly EvaluateOptions _options = EvaluateOptions.None;
            private readonly Dictionary<Type, HashSet<Type>> _implicitPrimitiveConversionTable = new Dictionary<Type, HashSet<Type>>() {
                { typeof(sbyte), new HashSet<Type> { typeof(short), typeof(int), typeof(long), typeof(float), typeof(double), typeof(decimal) }},
                { typeof(byte), new HashSet<Type> { typeof(short), typeof(ushort), typeof(int), typeof(uint), typeof(long), typeof(ulong), typeof(float), typeof(double), typeof(decimal) }},
                { typeof(short), new HashSet<Type> { typeof(int), typeof(long), typeof(float), typeof(double), typeof(decimal) }},
                { typeof(ushort), new HashSet<Type> { typeof(int), typeof(uint), typeof(long), typeof(ulong), typeof(float), typeof(double), typeof(decimal) }},
                { typeof(int), new HashSet<Type> { typeof(long), typeof(float), typeof(double), typeof(decimal) }},
                { typeof(uint), new HashSet<Type> { typeof(long), typeof(ulong), typeof(float), typeof(double), typeof(decimal) }},
                { typeof(long), new HashSet<Type> { typeof(float), typeof(double), typeof(decimal) }},
                { typeof(char), new HashSet<Type> { typeof(ushort), typeof(int), typeof(uint), typeof(long), typeof(ulong), typeof(float), typeof(double), typeof(decimal) }},
                { typeof(float), new HashSet<Type> { typeof(double) }},
                { typeof(ulong), new HashSet<Type> { typeof(float), typeof(double), typeof(decimal) }},
            };

            private bool Ordinal { get { return (_options & EvaluateOptions.MatchStringsOrdinal) == EvaluateOptions.MatchStringsOrdinal; } }
            private bool IgnoreCaseString { get { return (_options & EvaluateOptions.MatchStringsWithIgnoreCase) == EvaluateOptions.MatchStringsWithIgnoreCase; } }
            private bool Checked { get { return (_options & EvaluateOptions.OverflowProtection) == EvaluateOptions.OverflowProtection; } }
            private bool CSharpNull { get { return true; /* TODO: should we use null like C#? */ } }

            public HackExpressionVistor(IDictionary<string, object> parameters, EvaluateOptions options)
            {
                _parameters = parameters;
                _options = options;
            }

            public HackExpressionVistor(L.ParameterExpression context, EvaluateOptions options)
            {
                _context = context;
                _options = options;
            }

            public L.Expression Result => _result;

            public override void Visit(LogicalExpression expression)
            {
                throw new NotImplementedException();
            }

            public override void Visit(TernaryExpression expression)
            {
                expression.LeftExpression.Accept(this);
                var test = _result;

                expression.MiddleExpression.Accept(this);
                var ifTrue = _result;

                expression.RightExpression.Accept(this);
                var ifFalse = _result;

                _result = L.Expression.Condition(test, ifTrue, ifFalse);
            }

            public override void Visit(BinaryExpression expression)
            {
                expression.LeftExpression.Accept(this);
                var left = _result;

                expression.RightExpression.Accept(this);
                var right = _result;

                switch (expression.Type)
                {
                    case BinaryExpressionType.And:
                        _result = L.Expression.AndAlso(left, right);
                        break;
                    case BinaryExpressionType.Or:
                        _result = L.Expression.OrElse(left, right);
                        break;
                    case BinaryExpressionType.NotEqual:
                        _result = WithCommonNumericType(left, right, L.Expression.NotEqual, expression.Type);
                        break;
                    case BinaryExpressionType.LesserOrEqual:
                        _result = WithCommonNumericType(left, right, L.Expression.LessThanOrEqual, expression.Type);
                        break;
                    case BinaryExpressionType.GreaterOrEqual:
                        _result = WithCommonNumericType(left, right, L.Expression.GreaterThanOrEqual, expression.Type);
                        break;
                    case BinaryExpressionType.Lesser:
                        _result = WithCommonNumericType(left, right, L.Expression.LessThan, expression.Type);
                        break;
                    case BinaryExpressionType.Greater:
                        _result = WithCommonNumericType(left, right, L.Expression.GreaterThan, expression.Type);
                        break;
                    case BinaryExpressionType.Equal:
                        _result = WithCommonNumericType(left, right, L.Expression.Equal, expression.Type);
                        break;
                    case BinaryExpressionType.Minus:
                        if (Checked) _result = WithCommonNumericType(left, right, L.Expression.SubtractChecked);
                        else _result = WithCommonNumericType(left, right, L.Expression.Subtract);
                        break;
                    case BinaryExpressionType.Plus:
                        if (Checked) _result = WithCommonNumericType(left, right, L.Expression.AddChecked);
                        else _result = WithCommonNumericType(left, right, L.Expression.Add);
                        break;
                    case BinaryExpressionType.Modulo:
                        _result = WithCommonNumericType(left, right, L.Expression.Modulo);
                        break;
                    case BinaryExpressionType.Div:
                        _result = WithCommonNumericType(left, right, L.Expression.Divide);
                        break;
                    case BinaryExpressionType.Times:
                        if (Checked) _result = WithCommonNumericType(left, right, L.Expression.MultiplyChecked);
                        else _result = WithCommonNumericType(left, right, L.Expression.Multiply);
                        break;
                    case BinaryExpressionType.BitwiseOr:
                        _result = L.Expression.Or(left, right);
                        break;
                    case BinaryExpressionType.BitwiseAnd:
                        _result = L.Expression.And(left, right);
                        break;
                    case BinaryExpressionType.BitwiseXOr:
                        _result = L.Expression.ExclusiveOr(left, right);
                        break;
                    case BinaryExpressionType.LeftShift:
                        _result = L.Expression.LeftShift(left, right);
                        break;
                    case BinaryExpressionType.RightShift:
                        _result = L.Expression.RightShift(left, right);
                        break;
                    default:
                        throw new ArgumentOutOfRangeException();
                }
            }

            public override void Visit(UnaryExpression expression)
            {
                expression.Expression.Accept(this);
                switch (expression.Type)
                {
                    case UnaryExpressionType.Not:
                        _result = L.Expression.Not(_result);
                        break;
                    case UnaryExpressionType.Negate:
                        _result = L.Expression.Negate(_result);
                        break;
                    case UnaryExpressionType.BitwiseNot:
                        _result = L.Expression.Not(_result);
                        break;
                    default:
                        throw new ArgumentOutOfRangeException();
                }
            }

            public override void Visit(ValueExpression expression)
            {
                _result = L.Expression.Constant(expression.Value);
            }

            public override void Visit(Function function)
            {
                var args = new L.Expression[function.Expressions.Length];
                for (int i = 0; i < function.Expressions.Length; i++)
                {
                    function.Expressions[i].Accept(this);
                    args[i] = _result;
                }

                string functionName = function.Identifier.Name.ToLowerInvariant();
                if (functionName == "if")
                {
                    var numberTypePriority = new Type[] { typeof(double), typeof(float), typeof(long), typeof(int), typeof(short) };
                    var index1 = Array.IndexOf(numberTypePriority, args[1].Type);
                    var index2 = Array.IndexOf(numberTypePriority, args[2].Type);
                    if (index1 >= 0 && index2 >= 0 && index1 != index2)
                    {
                        args[1] = L.Expression.Convert(args[1], numberTypePriority[Math.Min(index1, index2)]);
                        args[2] = L.Expression.Convert(args[2], numberTypePriority[Math.Min(index1, index2)]);
                    }
                    _result = L.Expression.Condition(args[0], args[1], args[2]);
                    return;
                }
                else if (functionName == "in")
                {
                    var items = L.Expression.NewArrayInit(args[0].Type,
                            new ArraySegment<L.Expression>(args, 1, args.Length - 1));
                    var smi = typeof(Array).GetRuntimeMethod("IndexOf", new[] { typeof(Array), typeof(object) });
                    var r = L.Expression.Call(smi, L.Expression.Convert(items, typeof(Array)), L.Expression.Convert(args[0], typeof(object)));
                    _result = L.Expression.GreaterThanOrEqual(r, L.Expression.Constant(0));
                    return;
                }

                //Context methods take precedence over built-in functions because they're user-customisable.
                var mi = FindMethod(function.Identifier.Name, args);
                if (mi != null)
                {
                    _result = L.Expression.Call(_context, mi.BaseMethodInfo, mi.PreparedArguments);
                    return;
                }

                switch (functionName)
                {
                    case "min":
                        var minArg0 = L.Expression.Convert(args[0], typeof(double));
                        var minArg1 = L.Expression.Convert(args[1], typeof(double));
                        _result = L.Expression.Condition(L.Expression.LessThan(minArg0, minArg1), minArg0, minArg1);
                        break;
                    case "max":
                        var maxArg0 = L.Expression.Convert(args[0], typeof(double));
                        var maxArg1 = L.Expression.Convert(args[1], typeof(double));
                        _result = L.Expression.Condition(L.Expression.GreaterThan(maxArg0, maxArg1), maxArg0, maxArg1);
                        break;
                    case "pow":
                        var powArg0 = L.Expression.Convert(args[0], typeof(double));
                        var powArg1 = L.Expression.Convert(args[1], typeof(double));
                        _result = L.Expression.Power(powArg0, powArg1);
                        break;
                    default:
                        throw new MissingMethodException($"method not found: {functionName}");
                }
            }

            public override void Visit(Identifier function)
            {
                if (_context == null)
                {
                    _result = L.Expression.Constant(_parameters[function.Name]);
                }
                else
                {
                    _result = L.Expression.PropertyOrField(_context, function.Name);
                }
            }

            private ExtendedMethodInfo FindMethod(string methodName, L.Expression[] methodArgs)
            {
                if (_context == null) return null;

                TypeInfo contextTypeInfo = _context.Type.GetTypeInfo();
                TypeInfo objectTypeInfo = typeof(object).GetTypeInfo();
                do
                {
                    var methods = contextTypeInfo.DeclaredMethods.Where(m => m.Name.Equals(methodName, StringComparison.OrdinalIgnoreCase) && m.IsPublic && !m.IsStatic);
                    var candidates = new List<ExtendedMethodInfo>();
                    foreach (var potentialMethod in methods)
                    {
                        var methodParams = potentialMethod.GetParameters();
                        var preparedArguments = PrepareMethodArgumentsIfValid(methodParams, methodArgs);

                        if (preparedArguments != null)
                        {
                            var candidate = new ExtendedMethodInfo()
                            {
                                BaseMethodInfo = potentialMethod,
                                PreparedArguments = preparedArguments.Item2,
                                Score = preparedArguments.Item1
                            };
                            if (candidate.Score == 0) return candidate;
                            candidates.Add(candidate);
                        }
                    }
                    if (candidates.Any()) return candidates.OrderBy(method => method.Score).First();
                    contextTypeInfo = contextTypeInfo.BaseType.GetTypeInfo();
                } while (contextTypeInfo != objectTypeInfo);
                return null;
            }

            /// <summary>
            /// Returns a tuple where the first item is a score, and the second is a list of prepared arguments. 
            /// Score is a simplified indicator of how close the arguments' types are to the parameters'. A score of 0 indicates a perfect match between arguments and parameters. 
            /// Prepared arguments refers to having the arguments implicitly converted where necessary, and "params" arguments collated into one array.
            /// </summary>
            /// <param name="parameters"></param>
            /// <param name="arguments"></param>
            /// <returns></returns>
            private Tuple<int, L.Expression[]> PrepareMethodArgumentsIfValid(ParameterInfo[] parameters, L.Expression[] arguments)
            {
                if (!parameters.Any() && !arguments.Any()) return Tuple.Create(0, arguments);
                if (!parameters.Any()) return null;

                var lastParameter = parameters.Last();
                bool hasParamsKeyword = lastParameter.IsDefined(typeof(ParamArrayAttribute));
                if (hasParamsKeyword && parameters.Length > arguments.Length) return null;
                L.Expression[] newArguments = new L.Expression[parameters.Length];
                L.Expression[] paramsKeywordArgument = null;
                Type paramsElementType = null;
                int paramsParameterPosition = 0;
                if (!hasParamsKeyword)
                {
                    if (parameters.Length != arguments.Length) return null;
                }
                else
                {
                    paramsParameterPosition = lastParameter.Position;
                    paramsElementType = lastParameter.ParameterType.GetElementType();
                    paramsKeywordArgument = new L.Expression[arguments.Length - parameters.Length + 1];
                }

                int functionMemberScore = 0;
                for (int i = 0; i < arguments.Length; i++)
                {
                    var isParamsElement = hasParamsKeyword && i >= paramsParameterPosition;
                    var argument = arguments[i];
                    var argumentType = argument.Type;
                    var parameterType = isParamsElement ? paramsElementType : parameters[i].ParameterType;
                    if (argumentType != parameterType)
                    {
                        bool canCastImplicitly = TryCastImplicitly(argumentType, parameterType, ref argument);
                        if (!canCastImplicitly) return null;
                        functionMemberScore++;
                    }
                    if (!isParamsElement)
                    {
                        newArguments[i] = argument;
                    }
                    else
                    {
                        paramsKeywordArgument[i - paramsParameterPosition] = argument;
                    }
                }

                if (hasParamsKeyword)
                {
                    newArguments[paramsParameterPosition] = L.Expression.NewArrayInit(paramsElementType, paramsKeywordArgument);
                }
                return Tuple.Create(functionMemberScore, newArguments);
            }

            private bool TryCastImplicitly(Type from, Type to, ref L.Expression argument)
            {
                if (CSharpNull && IsNullable(to))
                {
                    var fromWithoutNullable = GetNullableType(from) ?? from;
                    var toWithoutNullable = GetNullableType(to);
                    if (fromWithoutNullable != toWithoutNullable)
                    {
                        bool convertingFromPrimitiveType = _implicitPrimitiveConversionTable.TryGetValue(fromWithoutNullable, out var possibleConversions);
                        if (!convertingFromPrimitiveType || !possibleConversions.Contains(toWithoutNullable))
                        {
                            argument = null;
                            return false;
                        }
                    }
                }
                else
                {
                    bool convertingFromPrimitiveType = _implicitPrimitiveConversionTable.TryGetValue(from, out var possibleConversions);
                    if (!convertingFromPrimitiveType || !possibleConversions.Contains(to))
                    {
                        argument = null;
                        return false;
                    }
                }
                argument = L.Expression.Convert(argument, to);
                return true;
            }

            private L.Expression WithCommonNumericType(L.Expression left, L.Expression right,
                Func<L.Expression, L.Expression, L.Expression> action, BinaryExpressionType expressiontype = BinaryExpressionType.Unknown)
            {
                action = UnwrapNullableAction(left, right, action, expressiontype);
                left = UnwrapNullable(left);
                right = UnwrapNullable(right);

                if (_options.HasFlag(EvaluateOptions.BooleanCalculation))
                {
                    if (left.Type == typeof(bool))
                    {
                        left = L.Expression.Condition(left, L.Expression.Constant(1.0), L.Expression.Constant(0.0));
                    }

                    if (right.Type == typeof(bool))
                    {
                        right = L.Expression.Condition(right, L.Expression.Constant(1.0), L.Expression.Constant(0.0));
                    }
                }

                var precedence = new[]
                {
                    typeof(decimal),
                    typeof(double),
                    typeof(float),
                    typeof(ulong),
                    typeof(long),
                    typeof(uint),
                    typeof(int),
                    typeof(ushort),
                    typeof(short),
                    typeof(byte),
                    typeof(sbyte)
                };

                int l = Array.IndexOf(precedence, left.Type);
                int r = Array.IndexOf(precedence, right.Type);
                if (l >= 0 && r >= 0)
                {
                    var type = precedence[Math.Min(l, r)];
                    if (left.Type != type)
                    {
                        left = L.Expression.Convert(left, type);
                    }

                    if (right.Type != type)
                    {
                        right = L.Expression.Convert(right, type);
                    }
                }
                L.Expression comparer = null;
                if (IgnoreCaseString)
                {
                    if (Ordinal) comparer = L.Expression.Property(null, typeof(StringComparer), "OrdinalIgnoreCase");
                    else comparer = L.Expression.Property(null, typeof(StringComparer), "CurrentCultureIgnoreCase");
                }
                else comparer = L.Expression.Property(null, typeof(StringComparer), "Ordinal");

                if (comparer != null && (typeof(string).Equals(left.Type) || typeof(string).Equals(right.Type)))
                {
                    switch (expressiontype)
                    {
                        case BinaryExpressionType.Equal: return L.Expression.Call(comparer, typeof(StringComparer).GetRuntimeMethod("Equals", new[] { typeof(string), typeof(string) }), new L.Expression[] { left, right });
                        case BinaryExpressionType.NotEqual: return L.Expression.Not(L.Expression.Call(comparer, typeof(StringComparer).GetRuntimeMethod("Equals", new[] { typeof(string), typeof(string) }), new L.Expression[] { left, right }));
                        case BinaryExpressionType.GreaterOrEqual: return L.Expression.GreaterThanOrEqual(L.Expression.Call(comparer, typeof(StringComparer).GetRuntimeMethod("Compare", new[] { typeof(string), typeof(string) }), new L.Expression[] { left, right }), L.Expression.Constant(0));
                        case BinaryExpressionType.LesserOrEqual: return L.Expression.LessThanOrEqual(L.Expression.Call(comparer, typeof(StringComparer).GetRuntimeMethod("Compare", new[] { typeof(string), typeof(string) }), new L.Expression[] { left, right }), L.Expression.Constant(0));
                        case BinaryExpressionType.Greater: return L.Expression.GreaterThan(L.Expression.Call(comparer, typeof(StringComparer).GetRuntimeMethod("Compare", new[] { typeof(string), typeof(string) }), new L.Expression[] { left, right }), L.Expression.Constant(0));
                        case BinaryExpressionType.Lesser: return L.Expression.LessThan(L.Expression.Call(comparer, typeof(StringComparer).GetRuntimeMethod("Compare", new[] { typeof(string), typeof(string) }), new L.Expression[] { left, right }), L.Expression.Constant(0));
                    }
                }
                return action(left, right);
            }

            private bool IsNullable(Type type)
            {
                var ti = type.GetTypeInfo();
                return ti.IsGenericType && ti.GetGenericTypeDefinition() == typeof(Nullable<>);
            }

            private bool IsNullable(L.Expression expression)
            {
                return IsNullable(expression.Type);
            }

            private Type GetNullableType(Type type)
            {
                return IsNullable(type) ? type.GetTypeInfo().GenericTypeArguments[0] : null;
            }

            private Type GetNullableType(L.Expression expression)
            {
                return GetNullableType(expression.Type);
            }

            private L.Expression HaveValueExpression(L.Expression expression)
            {
                if (IsNullable(expression))
                {
                    return L.Expression.Property(expression, "HasValue");
                }
                else
                {
                    return L.Expression.Constant(true);
                }
            }

            private Func<L.Expression, L.Expression, L.Expression> UnwrapNullableAction(L.Expression left, L.Expression right, Func<L.Expression, L.Expression, L.Expression> action, BinaryExpressionType expressiontype)
            {
                if (CSharpNull && (IsNullable(left) || IsNullable(right)))
                {
                    bool castToNullable = false;
                    // Result to use when at least one operand is null.
                    L.Expression nullMismatch;
                    switch (expressiontype)
                    {
                        case BinaryExpressionType.Equal:
                            // null == null is true; null == value is false.
                            nullMismatch = L.Expression.Equal(HaveValueExpression(left), HaveValueExpression(right));
                            break;

                        case BinaryExpressionType.NotEqual:
                            // null != null is false; null != value is true.
                            nullMismatch = L.Expression.NotEqual(HaveValueExpression(left), HaveValueExpression(right));
                            break;

                        case BinaryExpressionType.Lesser:
                        case BinaryExpressionType.LesserOrEqual:
                        case BinaryExpressionType.Greater:
                        case BinaryExpressionType.GreaterOrEqual:
                            // Ordering comparisons involving null are always false.
                            nullMismatch = L.Expression.Constant(false);
                            break;

                        default:
                            // Arithmetic with null yields null, so the result type becomes Nullable<T>.
                            castToNullable = true;
                            nullMismatch = L.Expression.Constant(null);
                            break;
                    }

                    return (l, r) =>
                    {
                        var act = action(l, r);
                        var nullableType = typeof(Nullable<>).MakeGenericType(act.Type);
                        // Both branches of Condition must have the same type, hence the casts.
                        return L.Expression.Condition(
                            L.Expression.And(HaveValueExpression(left), HaveValueExpression(right)),
                            castToNullable ? L.Expression.Convert(act, nullableType) : act,
                            castToNullable ? L.Expression.Convert(nullMismatch, nullableType) : nullMismatch);
                    };
                }
                }
                else
                {
                    return action;
                }
            }

            private L.Expression UnwrapNullable(L.Expression expression)
            {
                if (IsNullable(expression))
                {
                    return L.Expression.Condition(
                        HaveValueExpression(expression),
                        L.Expression.Property(expression, "Value"),
                        L.Expression.Default(GetNullableType(expression)));
                }

                return expression;
            }
        }

It's not a perfect solution: I could change more code to get a simpler resulting expression tree, but I don't need 101% performance from this code, only that all automated tests pass.
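
The arithmetic case (the default branch with castToNullable) can be sketched in isolation as follows. This is only an illustration under assumed names (HasValue, Unwrap, NullPropagating are hypothetical helpers, not NCalc2 members): the operation runs on the unwrapped values, and the result is widened to Nullable<T> so that null can be returned when either operand is missing.

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical helper: HasValue for Nullable<T> operands, constant true otherwise.
Expression HasValue(Expression e) =>
    Nullable.GetUnderlyingType(e.Type) != null
        ? Expression.Property(e, "HasValue")
        : Expression.Constant(true);

// Hypothetical helper: unwrap Nullable<T> to T (default(T) when null).
Expression Unwrap(Expression e) =>
    Nullable.GetUnderlyingType(e.Type) != null
        ? Expression.Call(e, e.Type.GetMethod("GetValueOrDefault", Type.EmptyTypes))
        : e;

// Apply op to the unwrapped operands; return null unless both have a value.
Expression NullPropagating(Expression left, Expression right,
    Func<Expression, Expression, Expression> op)
{
    var result = op(Unwrap(left), Unwrap(right));
    var nullableType = typeof(Nullable<>).MakeGenericType(result.Type);
    return Expression.Condition(
        Expression.AndAlso(HasValue(left), HasValue(right)),
        Expression.Convert(result, nullableType), // T -> T?
        Expression.Default(nullableType));        // null of type T?
}

var x = Expression.Parameter(typeof(int?), "x");
var y = Expression.Parameter(typeof(int?), "y");
var add = Expression.Lambda<Func<int?, int?, int?>>(
    NullPropagating(x, y, Expression.Add), x, y).Compile();

Console.WriteLine(add(1, 2));             // 3
Console.WriteLine(add(null, 2).HasValue); // False
```

Using Expression.Default for the null branch avoids converting an object-typed null constant, which is one way the resulting tree could be made a little simpler.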

sklose (Owner) commented Oct 28, 2022

Thanks for raising this issue. If you can turn your code into a PR, I am happy to merge this change. I agree it should be behind an opt-in flag so that it does not break existing users.

Yankes (Author) commented Oct 28, 2022

OK, I will probably have some free time next week to prepare a PR for this.
