From 1f2b2e6c79f4a09ddaeb29b17748c8bc0fa4a91c Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Mon, 22 Jul 2019 22:19:15 +0200 Subject: [PATCH] Minor code cleanup in SqlTranslatingEV --- ...lationalSqlTranslatingExpressionVisitor.cs | 55 +++++++++---------- 1 file changed, 27 insertions(+), 28 deletions(-) diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index f70fa4b44e2..dcda4ee947e 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -173,9 +173,9 @@ protected override Expression VisitMember(MemberExpression memberExpression) return result; } - return TranslationFailed(memberExpression.Expression, innerExpression) + return TranslationFailed(memberExpression.Expression, innerExpression, out var sqlInnerExpression) ? null - : Dependencies.MemberTranslatorProvider.Translate((SqlExpression)innerExpression, memberExpression.Member, memberExpression.Type); + : Dependencies.MemberTranslatorProvider.Translate(sqlInnerExpression, memberExpression.Member, memberExpression.Type); } private bool TryBindMember(Expression source, MemberIdentity member, out Expression expression) @@ -329,8 +329,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp } // MethodCall translators - var @object = Visit(methodCallExpression.Object); - if (TranslationFailed(methodCallExpression.Object, @object)) + if (TranslationFailed(methodCallExpression.Object, Visit(methodCallExpression.Object), out var sqlObject)) { return null; } @@ -338,15 +337,15 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp var arguments = new SqlExpression[methodCallExpression.Arguments.Count]; for (var i = 0; i < arguments.Length; i++) { - var argument = Visit(methodCallExpression.Arguments[i]); - if (TranslationFailed(methodCallExpression.Arguments[i], argument)) + var argument = methodCallExpression.Arguments[i]; + if (TranslationFailed(argument, Visit(argument), out var sqlArgument)) { return null; } - arguments[i] = (SqlExpression)argument; + arguments[i] = sqlArgument; } - return Dependencies.MethodCallTranslatorProvider.Translate(_model, (SqlExpression)@object, methodCallExpression.Method, arguments); + return Dependencies.MethodCallTranslatorProvider.Translate(_model, sqlObject, methodCallExpression.Method, arguments); } private static Expression TryRemoveImplicitConvert(Expression expression) @@ -427,19 +426,16 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) var left = TryRemoveImplicitConvert(binaryExpression.Left); var right = TryRemoveImplicitConvert(binaryExpression.Right); - left = Visit(left); - right = Visit(right); - - if (TranslationFailed(binaryExpression.Left, left) - || TranslationFailed(binaryExpression.Right, right)) + if (TranslationFailed(binaryExpression.Left, Visit(left), out var sqlLeft) + || TranslationFailed(binaryExpression.Right, Visit(right), out var sqlRight)) { return null; } return _sqlExpressionFactory.MakeBinary( binaryExpression.NodeType, - (SqlExpression)left, - (SqlExpression)right, + sqlLeft, + sqlRight, null); } @@ -493,19 +489,14 @@ protected override Expression VisitConditional(ConditionalExpression conditional var ifTrue = Visit(conditionalExpression.IfTrue); var ifFalse = Visit(conditionalExpression.IfFalse); - if (TranslationFailed(conditionalExpression.Test, test) - || TranslationFailed(conditionalExpression.IfTrue, ifTrue) - || TranslationFailed(conditionalExpression.IfFalse, ifFalse)) + if (TranslationFailed(conditionalExpression.Test, test, out var sqlTest) + || TranslationFailed(conditionalExpression.IfTrue, ifTrue, out var sqlIfTrue) + || TranslationFailed(conditionalExpression.IfFalse, ifFalse, out var sqlIfFalse)) { return null; } - return _sqlExpressionFactory.Case( - new[] - { - new CaseWhenClause((SqlExpression)test,(SqlExpression) ifTrue) - }, - (SqlExpression)ifFalse); + return _sqlExpressionFactory.Case(new[] { new CaseWhenClause(sqlTest, sqlIfTrue) }, sqlIfFalse); } protected override Expression VisitUnary(UnaryExpression unaryExpression) @@ -517,12 +508,11 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) return unaryExpression.Update(operand); } - if (TranslationFailed(unaryExpression.Operand, operand)) + if (TranslationFailed(unaryExpression.Operand, operand, out var sqlOperand)) { return null; } - var sqlOperand = (SqlExpression)operand; switch (unaryExpression.NodeType) { case ExpressionType.Not: @@ -557,7 +547,16 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) } [DebuggerStepThrough] - private bool TranslationFailed(Expression original, Expression translation) - => original != null && !(translation is SqlExpression); + private bool TranslationFailed(Expression original, Expression translation, out SqlExpression castTranslation) + { + if (original != null && !(translation is SqlExpression)) + { + castTranslation = null; + return true; + } + + castTranslation = translation as SqlExpression; + return false; + } } }