Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rule 4.1.2: Numbers of a float type should not be directly compared #323

Merged
merged 20 commits into from
Oct 6, 2020
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import com.pinterest.ktlint.core.ast.ElementType
import org.cqfn.diktat.common.config.rules.RulesConfig
import org.cqfn.diktat.ruleset.constants.Warnings.FLOAT_IN_ACCURATE_CALCULATIONS
import org.cqfn.diktat.ruleset.utils.findLocalDeclaration
import org.cqfn.diktat.ruleset.utils.getFunctionName
import org.jetbrains.kotlin.com.intellij.lang.ASTNode
import org.jetbrains.kotlin.com.intellij.psi.PsiElement
import org.jetbrains.kotlin.lexer.KtTokens
Expand All @@ -13,11 +14,13 @@ import org.jetbrains.kotlin.psi.KtCallExpression
import org.jetbrains.kotlin.psi.KtDotQualifiedExpression
import org.jetbrains.kotlin.psi.KtExpression
import org.jetbrains.kotlin.psi.KtNameReferenceExpression
import org.jetbrains.kotlin.psi.psiUtil.parentsWithSelf
import org.jetbrains.kotlin.psi.psiUtil.startOffset

/**
* Rule that checks that floating-point numbers are not used for accurate calculations
* 1. Checks that floating-point numbers are not used in arithmetic binary expressions
* Exception: allows arithmetic operations only when absolute value of result is immediately used in comparison
* Fixme: detect variables by type, not only floating-point literals
*/
class AccurateCalculationsRule(private val configRules: List<RulesConfig>) : Rule("accurate-calculations") {
Expand All @@ -32,7 +35,9 @@ class AccurateCalculationsRule(private val configRules: List<RulesConfig>) : Rul
KtTokens.GT, KtTokens.LT, KtTokens.LTEQ, KtTokens.GTEQ,
KtTokens.EQEQ
)
private val comparisonOperators = listOf(KtTokens.LT, KtTokens.LTEQ, KtTokens.GT, KtTokens.GTEQ)
private val arithmeticOperationsFunctions = listOf("equals", "compareTo")
private val comparisonFunctions = listOf("compareTo")
}

override fun visit(node: ASTNode,
Expand All @@ -51,20 +56,20 @@ class AccurateCalculationsRule(private val configRules: List<RulesConfig>) : Rul
@Suppress("UnsafeCallOnNullableType")
private fun handleBinaryExpression(expression: KtBinaryExpression) = expression
.takeIf { it.operationToken in arithmeticOperationTokens }
?.let { expr ->
?.takeUnless { it.parentsWithSelf.any(::isComparisonWithAbs) }
?.run {
// !! is safe because `KtBinaryExpression#left` is annotated `Nullable IfNotParsed`
val floatValue = expr.left!!.takeIf { it.isFloatingPoint() }
?: expr.right!!.takeIf { it.isFloatingPoint() }
checkFloatValue(floatValue, expr)
val floatValue = left!!.takeIf { it.isFloatingPoint() }
?: right!!.takeIf { it.isFloatingPoint() }
checkFloatValue(floatValue, this)
println()
}

private fun handleFunction(expression: KtDotQualifiedExpression) = expression
.takeIf { it.selectorExpression is KtCallExpression }
?.run { receiverExpression to selectorExpression as KtCallExpression }
?.takeIf {
(it.second.calleeExpression as? KtNameReferenceExpression)
?.getReferencedName() in arithmeticOperationsFunctions
}
?.takeIf { it.second.getFunctionName() in arithmeticOperationsFunctions }
?.takeUnless { expression.parentsWithSelf.any(::isComparisonWithAbs) }
?.let { (receiverExpression, selectorExpression) ->
val floatValue = receiverExpression.takeIf { it.isFloatingPoint() }
?: selectorExpression
Expand All @@ -81,14 +86,65 @@ class AccurateCalculationsRule(private val configRules: List<RulesConfig>) : Rul
"float value of <${floatValue.text}> used in arithmetic expression in ${expression.text}", expression.startOffset, expression.node)
}
}

private fun isComparisonWithAbs(psiElement: PsiElement) =
when (psiElement) {
is KtBinaryExpression -> psiElement.isComparisonWithAbs()
is KtDotQualifiedExpression -> psiElement.isComparisonWithAbs()
else -> false
}

private fun KtBinaryExpression.isComparisonWithAbs() =
takeIf { it.operationToken in comparisonOperators }
?.run { left as? KtCallExpression ?: right as? KtCallExpression }
?.run { calleeExpression as? KtNameReferenceExpression }
?.getReferencedName()
?.equals("abs")
?: false

private fun KtDotQualifiedExpression.isComparisonWithAbs() =
takeIf {
it.selectorExpression.run {
this is KtCallExpression && getFunctionName() in comparisonFunctions
}
}
?.run {
(selectorExpression as KtCallExpression)
.valueArguments
.singleOrNull()
?.let { it.getArgumentExpression() as? KtCallExpression }
?.isAbsOfFloat()
?: false ||
(receiverExpression as? KtCallExpression).isAbsOfFloat()
}
?: false

private fun KtCallExpression?.isAbsOfFloat() = this
?.run {
(calleeExpression as? KtNameReferenceExpression)
?.getReferencedName()
?.equals("abs")
?.and(
valueArguments
.singleOrNull()
?.getArgumentExpression()
?.isFloatingPoint()
?: false)
?: false
}
?: false
}

private fun PsiElement.isFloatingPoint() =
@Suppress("UnsafeCallOnNullableType")
private fun PsiElement.isFloatingPoint(): Boolean =
node.elementType == ElementType.FLOAT_LITERAL ||
node.elementType == ElementType.FLOAT_CONSTANT ||
((this as? KtNameReferenceExpression)
?.findLocalDeclaration()
?.initializer
?.node
?.run { elementType == ElementType.FLOAT_LITERAL || elementType == ElementType.FLOAT_CONSTANT }
?: false) ||
((this as? KtBinaryExpression)
?.run { left!!.isFloatingPoint() && right!!.isFloatingPoint() }
?: false)
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,5 @@ fun KtNameReferenceExpression.findLocalDeclaration(): KtProperty? = parents
}
}
.firstOrNull()

fun KtCallExpression.getFunctionName() = (calleeExpression as? KtNameReferenceExpression)?.getReferencedName()
Original file line number Diff line number Diff line change
Expand Up @@ -150,4 +150,37 @@ class AccurateCalculationsWarnTest : LintTestBase(::AccurateCalculationsRule) {
LintError(16, 9, ruleId, warnText("x", "x %= 2"), false)
)
}

@Test
@Tag(WarningNames.FLOAT_IN_ACCURATE_CALCULATIONS)
fun `should allow arithmetic operations inside abs in comparison`() {
lintMethod(
"""
|import kotlin.math.abs
|
petertrr marked this conversation as resolved.
Show resolved Hide resolved
|fun foo() {
| if (abs(1.0 - 0.999) < 1e-6) {
| println("Comparison with tolerance")
| }
|
| 1e-6 > abs(1.0 - 0.999)
| abs(1.0 - 0.999).compareTo(1e-6) < 0
| 1e-6.compareTo(abs(1.0 - 0.999)) < 0
| abs(1.0 - 0.999) == 1e-6
|
| abs(1.0 - 0.999) < eps
| eps > abs(1.0 - 0.999)
|
| val x = 1.0
| val y = 0.999
| abs(x - y) < eps
| eps > abs(x - y)
| abs(1.0 - 0.999) == eps
|}
""".trimMargin(),
LintError(11, 5, ruleId, warnText("1e-6", "abs(1.0 - 0.999) == 1e-6"), false),
LintError(11, 9, ruleId, warnText("1.0", "1.0 - 0.999"), false),
LintError(20, 9, ruleId, warnText("1.0", "1.0 - 0.999"), false)
)
}
}