diff --git a/diktat-rules/src/main/kotlin/org/cqfn/diktat/ruleset/rules/calculations/AccurateCalculationsRule.kt b/diktat-rules/src/main/kotlin/org/cqfn/diktat/ruleset/rules/calculations/AccurateCalculationsRule.kt index 1df1701ee7..d810f293e0 100644 --- a/diktat-rules/src/main/kotlin/org/cqfn/diktat/ruleset/rules/calculations/AccurateCalculationsRule.kt +++ b/diktat-rules/src/main/kotlin/org/cqfn/diktat/ruleset/rules/calculations/AccurateCalculationsRule.kt @@ -5,6 +5,7 @@ import com.pinterest.ktlint.core.ast.ElementType import org.cqfn.diktat.common.config.rules.RulesConfig import org.cqfn.diktat.ruleset.constants.Warnings.FLOAT_IN_ACCURATE_CALCULATIONS import org.cqfn.diktat.ruleset.utils.findLocalDeclaration +import org.cqfn.diktat.ruleset.utils.getFunctionName import org.jetbrains.kotlin.com.intellij.lang.ASTNode import org.jetbrains.kotlin.com.intellij.psi.PsiElement import org.jetbrains.kotlin.lexer.KtTokens @@ -13,11 +14,13 @@ import org.jetbrains.kotlin.psi.KtCallExpression import org.jetbrains.kotlin.psi.KtDotQualifiedExpression import org.jetbrains.kotlin.psi.KtExpression import org.jetbrains.kotlin.psi.KtNameReferenceExpression +import org.jetbrains.kotlin.psi.psiUtil.parentsWithSelf import org.jetbrains.kotlin.psi.psiUtil.startOffset /** * Rule that checks that floating-point numbers are not used for accurate calculations * 1. Checks that floating-point numbers are not used in arithmetic binary expressions + * Exception: allows arithmetic operations only when absolute value of result is immediately used in comparison * Fixme: detect variables by type, not only floating-point literals */ class AccurateCalculationsRule(private val configRules: List) : Rule("accurate-calculations") { @@ -32,7 +35,9 @@ class AccurateCalculationsRule(private val configRules: List) : Rul KtTokens.GT, KtTokens.LT, KtTokens.LTEQ, KtTokens.GTEQ, KtTokens.EQEQ ) + private val comparisonOperators = listOf(KtTokens.LT, KtTokens.LTEQ, KtTokens.GT, KtTokens.GTEQ) private val arithmeticOperationsFunctions = listOf("equals", "compareTo") + private val comparisonFunctions = listOf("compareTo") } override fun visit(node: ASTNode, @@ -51,20 +56,20 @@ class AccurateCalculationsRule(private val configRules: List) : Rul @Suppress("UnsafeCallOnNullableType") private fun handleBinaryExpression(expression: KtBinaryExpression) = expression .takeIf { it.operationToken in arithmeticOperationTokens } - ?.let { expr -> + ?.takeUnless { it.parentsWithSelf.any(::isComparisonWithAbs) } + ?.run { // !! is safe because `KtBinaryExpression#left` is annotated `Nullable IfNotParsed` - val floatValue = expr.left!!.takeIf { it.isFloatingPoint() } - ?: expr.right!!.takeIf { it.isFloatingPoint() } - checkFloatValue(floatValue, expr) + val floatValue = left!!.takeIf { it.isFloatingPoint() } + ?: right!!.takeIf { it.isFloatingPoint() } + checkFloatValue(floatValue, this) + println() } private fun handleFunction(expression: KtDotQualifiedExpression) = expression .takeIf { it.selectorExpression is KtCallExpression } ?.run { receiverExpression to selectorExpression as KtCallExpression } - ?.takeIf { - (it.second.calleeExpression as? KtNameReferenceExpression) - ?.getReferencedName() in arithmeticOperationsFunctions - } + ?.takeIf { it.second.getFunctionName() in arithmeticOperationsFunctions } + ?.takeUnless { expression.parentsWithSelf.any(::isComparisonWithAbs) } ?.let { (receiverExpression, selectorExpression) -> val floatValue = receiverExpression.takeIf { it.isFloatingPoint() } ?: selectorExpression @@ -81,9 +86,57 @@ class AccurateCalculationsRule(private val configRules: List) : Rul "float value of <${floatValue.text}> used in arithmetic expression in ${expression.text}", expression.startOffset, expression.node) } } + + private fun isComparisonWithAbs(psiElement: PsiElement) = + when (psiElement) { + is KtBinaryExpression -> psiElement.isComparisonWithAbs() + is KtDotQualifiedExpression -> psiElement.isComparisonWithAbs() + else -> false + } + + private fun KtBinaryExpression.isComparisonWithAbs() = + takeIf { it.operationToken in comparisonOperators } + ?.run { left as? KtCallExpression ?: right as? KtCallExpression } + ?.run { calleeExpression as? KtNameReferenceExpression } + ?.getReferencedName() + ?.equals("abs") + ?: false + + private fun KtDotQualifiedExpression.isComparisonWithAbs() = + takeIf { + it.selectorExpression.run { + this is KtCallExpression && getFunctionName() in comparisonFunctions + } + } + ?.run { + (selectorExpression as KtCallExpression) + .valueArguments + .singleOrNull() + ?.let { it.getArgumentExpression() as? KtCallExpression } + ?.isAbsOfFloat() + ?: false || + (receiverExpression as? KtCallExpression).isAbsOfFloat() + } + ?: false + + private fun KtCallExpression?.isAbsOfFloat() = this + ?.run { + (calleeExpression as? KtNameReferenceExpression) + ?.getReferencedName() + ?.equals("abs") + ?.and( + valueArguments + .singleOrNull() + ?.getArgumentExpression() + ?.isFloatingPoint() + ?: false) + ?: false + } + ?: false } -private fun PsiElement.isFloatingPoint() = +@Suppress("UnsafeCallOnNullableType") +private fun PsiElement.isFloatingPoint(): Boolean = node.elementType == ElementType.FLOAT_LITERAL || node.elementType == ElementType.FLOAT_CONSTANT || ((this as? KtNameReferenceExpression) @@ -91,4 +144,7 @@ private fun PsiElement.isFloatingPoint() = ?.initializer ?.node ?.run { elementType == ElementType.FLOAT_LITERAL || elementType == ElementType.FLOAT_CONSTANT } + ?: false) || + ((this as? KtBinaryExpression) + ?.run { left!!.isFloatingPoint() && right!!.isFloatingPoint() } ?: false) diff --git a/diktat-rules/src/main/kotlin/org/cqfn/diktat/ruleset/utils/PsiUtils.kt b/diktat-rules/src/main/kotlin/org/cqfn/diktat/ruleset/utils/PsiUtils.kt index afc3a5a118..bc9ea43406 100644 --- a/diktat-rules/src/main/kotlin/org/cqfn/diktat/ruleset/utils/PsiUtils.kt +++ b/diktat-rules/src/main/kotlin/org/cqfn/diktat/ruleset/utils/PsiUtils.kt @@ -87,3 +87,5 @@ fun KtNameReferenceExpression.findLocalDeclaration(): KtProperty? = parents } } .firstOrNull() + +fun KtCallExpression.getFunctionName() = (calleeExpression as? KtNameReferenceExpression)?.getReferencedName() diff --git a/diktat-rules/src/test/kotlin/org/cqfn/diktat/ruleset/chapter4/AccurateCalculationsWarnTest.kt b/diktat-rules/src/test/kotlin/org/cqfn/diktat/ruleset/chapter4/AccurateCalculationsWarnTest.kt index 1fc95a96ba..964e264207 100644 --- a/diktat-rules/src/test/kotlin/org/cqfn/diktat/ruleset/chapter4/AccurateCalculationsWarnTest.kt +++ b/diktat-rules/src/test/kotlin/org/cqfn/diktat/ruleset/chapter4/AccurateCalculationsWarnTest.kt @@ -150,4 +150,37 @@ class AccurateCalculationsWarnTest : LintTestBase(::AccurateCalculationsRule) { LintError(16, 9, ruleId, warnText("x", "x %= 2"), false) ) } + + @Test + @Tag(WarningNames.FLOAT_IN_ACCURATE_CALCULATIONS) + fun `should allow arithmetic operations inside abs in comparison`() { + lintMethod( + """ + |import kotlin.math.abs + | + |fun foo() { + | if (abs(1.0 - 0.999) < 1e-6) { + | println("Comparison with tolerance") + | } + | + | 1e-6 > abs(1.0 - 0.999) + | abs(1.0 - 0.999).compareTo(1e-6) < 0 + | 1e-6.compareTo(abs(1.0 - 0.999)) < 0 + | abs(1.0 - 0.999) == 1e-6 + | + | abs(1.0 - 0.999) < eps + | eps > abs(1.0 - 0.999) + | + | val x = 1.0 + | val y = 0.999 + | abs(x - y) < eps + | eps > abs(x - y) + | abs(1.0 - 0.999) == eps + |} + """.trimMargin(), + LintError(11, 5, ruleId, warnText("1e-6", "abs(1.0 - 0.999) == 1e-6"), false), + LintError(11, 9, ruleId, warnText("1.0", "1.0 - 0.999"), false), + LintError(20, 9, ruleId, warnText("1.0", "1.0 - 0.999"), false) + ) + } }