Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix correctness issue with CASE WHEN with expressions that have side-effects #4383

Merged
merged 22 commits into from
Jan 6, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
aa6db16
Fix correctness issue with CASE WHEN with expressions that have side-…
andygrove Dec 17, 2021
e0879f9
code cleanup and comments
andygrove Dec 17, 2021
5756850
Revert unnecessary change
andygrove Dec 17, 2021
f509f69
Revert unnecessary change
andygrove Dec 17, 2021
58bfdad
Add license header
andygrove Dec 20, 2021
5b5247d
Add more comments. Add optimization to stop processing branches once …
andygrove Dec 21, 2021
e9e5f50
Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalE…
andygrove Dec 22, 2021
58d2c26
Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalE…
andygrove Dec 22, 2021
2ffd030
Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalE…
andygrove Dec 22, 2021
044c98a
remove redundant check
andygrove Dec 22, 2021
e83b187
close thenValues resource earlier
andygrove Dec 22, 2021
76d63a4
close elseValues resource earlier
andygrove Dec 22, 2021
89da423
refactor for readability
andygrove Dec 22, 2021
96fc175
refactor to remove duplicate code
andygrove Dec 22, 2021
4a79414
simplify inverting cumulativePred for else condition
andygrove Jan 4, 2022
aad30f4
Merge remote-tracking branch 'nvidia/branch-22.02' into gpu-case-when…
andygrove Jan 4, 2022
1233870
Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalE…
andygrove Jan 5, 2022
bde6b85
fix compilation error
andygrove Jan 5, 2022
034cc01
Use AST NULL_LOGICAL_OR for computing cumulative predicate
andygrove Jan 5, 2022
5ae6922
Update isFirstTrueWhen to be null-safe
andygrove Jan 5, 2022
bb80e29
address feedback
andygrove Jan 5, 2022
98d2207
Add tests for predicates evaluating to null
andygrove Jan 6, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 21 additions & 13 deletions integration_tests/src/main/python/conditionals_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,28 +186,36 @@ def test_ifnull(data_gen):
'ifnull({}, b)'.format(null_lit),
'ifnull(a, {})'.format(null_lit)))

@pytest.mark.parametrize('data_gen', int_n_long_gens, ids=idfn)
@pytest.mark.parametrize('data_gen', [IntegerGen().with_special_case(2147483647)], ids=idfn)
def test_conditional_with_side_effects_col_col(data_gen):
gen = IntegerGen().with_special_case(2147483647)
assert_gpu_and_cpu_are_equal_collect(
lambda spark : two_col_df(spark, data_gen, gen).selectExpr(
'IF(b < 2147483647, b + 1, b)'),
lambda spark : unary_op_df(spark, data_gen).selectExpr(
'IF(a < 2147483647, a + 1, a)'),
conf = {'spark.sql.ansi.enabled':True})

@pytest.mark.parametrize('data_gen', int_n_long_gens, ids=idfn)
@pytest.mark.parametrize('data_gen', [IntegerGen().with_special_case(2147483647)], ids=idfn)
def test_conditional_with_side_effects_col_scalar(data_gen):
gen = IntegerGen().with_special_case(2147483647)
assert_gpu_and_cpu_are_equal_collect(
lambda spark : two_col_df(spark, data_gen, gen).selectExpr(
'IF(b < 2147483647, b + 1, 2147483647)',
'IF(b >= 2147483646, 2147483647, b + 1)'),
lambda spark : unary_op_df(spark, data_gen).selectExpr(
'IF(a < 2147483647, a + 1, 2147483647)',
'IF(a >= 2147483646, 2147483647, a + 1)'),
conf = {'spark.sql.ansi.enabled':True})

@pytest.mark.parametrize('data_gen', int_n_long_gens, ids=idfn)
@pytest.mark.parametrize('data_gen', [mk_str_gen('[0-9]{1,20}')], ids=idfn)
def test_conditional_with_side_effects_cast(data_gen):
gen = mk_str_gen('[0-9]{1,20}')
assert_gpu_and_cpu_are_equal_collect(
lambda spark : two_col_df(spark, data_gen, gen).selectExpr(
lambda spark : unary_op_df(spark, data_gen).selectExpr(
'IF(a RLIKE "^[0-9]{1,5}$", CAST(a AS INT), 0)'),
conf = {'spark.sql.ansi.enabled':True,
'spark.rapids.sql.expression.RLike': True})
'spark.rapids.sql.expression.RLike': True})

@pytest.mark.parametrize('data_gen', [mk_str_gen('[0-9]{1,9}')], ids=idfn)
def test_conditional_with_side_effects_case_when(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr(
'CASE \
WHEN a RLIKE "^[0-9]{1,3}$" THEN CAST(a AS INT) \
WHEN a RLIKE "^[0-9]{4,6}$" THEN CAST(a AS INT) + 123 \
ELSE -1 END'),
conf = {'spark.sql.ansi.enabled':True,
'spark.rapids.sql.expression.RLike': True})
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,61 @@ trait GpuConditionalExpression extends ComplexTypeMergingExpression with GpuExpr
!anyTrue.getBoolean
}
}

protected def filterBatch(
tbl: Table,
pred: ColumnVector,
colTypes: Array[DataType]): ColumnarBatch = {
withResource(tbl.filter(pred)) { filteredData =>
GpuColumnVector.from(filteredData, colTypes)
}
}

private def boolToInt(cv: ColumnVector): ColumnVector = {
withResource(GpuScalar.from(1, DataTypes.IntegerType)) { one =>
withResource(GpuScalar.from(0, DataTypes.IntegerType)) { zero =>
cv.ifElse(one, zero)
}
}
}

protected def gather(predicate: ColumnVector, t: GpuColumnVector): ColumnVector = {
// convert the predicate boolean column to numeric where 1 = true
// and 0 = false and then use `scan` with `sum` to convert to
// indices.
//
// For example, if the predicate evaluates to [F, F, T, F, T] then this
jlowe marked this conversation as resolved.
Show resolved Hide resolved
// gets translated first to [0, 0, 1, 0, 1] and then the scan operation
// will perform an exclusive sum on these values and
// produce [0, 0, 0, 1, 1]. Combining this with the original
// predicate boolean array results in the two T values mapping to
// indices 0 and 1, respectively.

withResource(boolToInt(predicate)) { boolsAsInts =>
withResource(boolsAsInts.scan(
ScanAggregation.sum(),
ScanType.EXCLUSIVE,
NullPolicy.INCLUDE)) { prefixSumExclusive =>

// for the entries in the gather map that do not represent valid
// values to be gathered, we change the value to Int.MinValue which
// will be treated as null values in the gather algorithm
val gatherMap = withResource(Scalar.fromInt(Int.MinValue)) {
outOfBoundsFlag => predicate.ifElse(prefixSumExclusive, outOfBoundsFlag)
}

andygrove marked this conversation as resolved.
Show resolved Hide resolved
withResource(new Table(t.getBase)) { tbl =>
withResource(gatherMap) { _ =>
andygrove marked this conversation as resolved.
Show resolved Hide resolved
withResource(tbl.gather(gatherMap)) { gatherTbl =>
gatherTbl.getColumn(0).incRefCount()
}
}
}
}
}
}


}

case class GpuIf(
Expand Down Expand Up @@ -203,59 +258,6 @@ case class GpuIf(
}
}

private def filterBatch(
tbl: Table,
pred: ColumnVector,
colTypes: Array[DataType]): ColumnarBatch = {
withResource(tbl.filter(pred)) { filteredData =>
GpuColumnVector.from(filteredData, colTypes)
}
}

private def boolToInt(cv: ColumnVector): ColumnVector = {
withResource(GpuScalar.from(1, DataTypes.IntegerType)) { one =>
withResource(GpuScalar.from(0, DataTypes.IntegerType)) { zero =>
cv.ifElse(one, zero)
}
}
}

private def gather(predicate: ColumnVector, t: GpuColumnVector): ColumnVector = {
// convert the predicate boolean column to numeric where 1 = true
// and 0 = false and then use `scan` with `sum` to convert to
// indices.
//
// For example, if the predicate evaluates to [F, F, T, F, T] then this
// gets translated first to [0, 0, 1, 0, 1] and then the scan operation
// will perform an exclusive sum on these values and
// produce [0, 0, 0, 1, 1]. Combining this with the original
// predicate boolean array results in the two T values mapping to
// indices 0 and 1, respectively.

withResource(boolToInt(predicate)) { boolsAsInts =>
withResource(boolsAsInts.scan(
ScanAggregation.sum(),
ScanType.EXCLUSIVE,
NullPolicy.INCLUDE)) { prefixSumExclusive =>

// for the entries in the gather map that do not represent valid
// values to be gathered, we change the value to Int.MinValue which
// will be treated as null values in the gather algorithm
val gatherMap = withResource(Scalar.fromInt(Int.MinValue)) {
outOfBoundsFlag => predicate.ifElse(prefixSumExclusive, outOfBoundsFlag)
}

withResource(new Table(t.getBase)) { tbl =>
withResource(gatherMap) { _ =>
withResource(tbl.gather(gatherMap)) { gatherTbl =>
gatherTbl.getColumn(0).incRefCount()
}
}
}
}
}
}

override def toString: String = s"if ($predicateExpr) $trueExpr else $falseExpr"

override def sql: String = s"(IF(${predicateExpr.sql}, ${trueExpr.sql}, ${falseExpr.sql}))"
Expand All @@ -274,6 +276,9 @@ case class GpuCaseWhen(
branches.map(_._2.dataType) ++ elseValue.map(_.dataType)
}

private lazy val branchesWithSideEffects =
branches.exists(_._2.asInstanceOf[GpuExpression].hasSideEffects)

override def nullable: Boolean = {
// Result is nullable if any of the branch is nullable, or if the else value is nullable
branches.exists(_._2.nullable) || elseValue.forall(_.nullable)
Expand Down Expand Up @@ -301,12 +306,139 @@ case class GpuCaseWhen(
}

override def columnarEval(batch: ColumnarBatch): Any = {
// `elseRet` will be closed in `computeIfElse`.
val elseRet = elseValue
.map(_.columnarEval(batch))
.getOrElse(GpuScalar(null, branches.last._2.dataType))
branches.foldRight[Any](elseRet) { case ((predicateExpr, trueExpr), falseRet) =>
computeIfElse(batch, predicateExpr, trueExpr, falseRet)
if (branchesWithSideEffects) {
columnarEvalWithSideEffects(batch)
} else {
// `elseRet` will be closed in `computeIfElse`.
val elseRet = elseValue
.map(_.columnarEval(batch))
.getOrElse(GpuScalar(null, branches.last._2.dataType))
branches.foldRight[Any](elseRet) {
case ((predicateExpr, trueExpr), falseRet) =>
computeIfElse(batch, predicateExpr, trueExpr, falseRet)
}
}
}
andygrove marked this conversation as resolved.
Show resolved Hide resolved

/**
* Perform lazy evaluation of each branch so that we only evaluate the THEN expressions
andygrove marked this conversation as resolved.
Show resolved Hide resolved
* against rows where the WHEN expression is true.
*/
private def columnarEvalWithSideEffects(batch: ColumnarBatch): Any = {
jlowe marked this conversation as resolved.
Show resolved Hide resolved
val colTypes = GpuColumnVector.extractTypes(batch)

// track cumulative state of predicate evaluation per row so that we never evaluate expressions
// for a row if an earlier expression has already been evaluated to true for that row
var cumulativePred: Option[GpuColumnVector] = None

// this variable contains the currently evaluated value for each row and gets updated
// as each branch is evaluated
var currentValue: Option[GpuColumnVector] = None

try {
withResource(GpuColumnVector.from(batch)) { tbl =>

// iterate over the WHEN THEN branches first
branches.foreach {
case (whenExpr, thenExpr) =>
// evaluate the WHEN predicate
withResource(GpuExpressionsUtils.columnarEvalToColumn(whenExpr, batch)) { whenBool =>
// we only want to evaluate where this WHEN is true and no previous WHEN has been true
val whenBoolAdjusted: GpuColumnVector = cumulativePred match {
case Some(prev) =>
withResource(prev.getBase.not()) { notPreviouslyTrue =>
jlowe marked this conversation as resolved.
Show resolved Hide resolved
GpuColumnVector.from(
whenBool.getBase.and(notPreviouslyTrue), DataTypes.BooleanType)
}
case None =>
whenBool.incRefCount()
}

withResource(whenBoolAdjusted) { _ =>
if (isAllTrue(whenBoolAdjusted)) {
// if this WHEN predicate is true for all rows and no previous predicate has
// been true then we can return immediately
return GpuExpressionsUtils.columnarEvalToColumn(thenExpr, batch)
}
val thenValues = withResource(filterBatch(tbl,
whenBoolAdjusted.getBase, colTypes)) {
trueBatch => GpuExpressionsUtils.columnarEvalToColumn(thenExpr, trueBatch)
}
withResource(thenValues) { _ =>
withResource(gather(whenBoolAdjusted.getBase, thenValues)) { gather =>
jlowe marked this conversation as resolved.
Show resolved Hide resolved
currentValue = Some(currentValue match {
case Some(v) =>
withResource(v) { _ =>
GpuColumnVector.from(whenBoolAdjusted.getBase.ifElse(gather, v.getBase),
dataType)
}
case _ =>
GpuColumnVector.from(gather.incRefCount(), dataType)
})
}
jlowe marked this conversation as resolved.
Show resolved Hide resolved

cumulativePred = Some(cumulativePred match {
jlowe marked this conversation as resolved.
Show resolved Hide resolved
case Some(prev) =>
withResource(prev) { _ =>
GpuColumnVector.from(whenBool.getBase.or(prev.getBase),
DataTypes.BooleanType)
}
case _ =>
whenBoolAdjusted.incRefCount()
})

if (isAllTrue(cumulativePred.get)) {
// no need to process any more branches or the else condition
return currentValue.get.incRefCount()
}
}
}
}
}

// invert the cumulative predicate to get the ELSE predicate
withResource(cumulativePred.get.getBase.not()) { elsePredicate =>
jlowe marked this conversation as resolved.
Show resolved Hide resolved
// replace null predicates with true (because this is an inverted predicate)
val elsePredNoNulls = withResource(Scalar.fromBool(true)) { t =>
jlowe marked this conversation as resolved.
Show resolved Hide resolved
GpuColumnVector.from(elsePredicate.replaceNulls(t), DataTypes.BooleanType)
jlowe marked this conversation as resolved.
Show resolved Hide resolved
}
withResource(elsePredNoNulls) { _ =>
elseValue match {
case Some(expr) =>
if (isAllTrue(elsePredNoNulls)) {
jlowe marked this conversation as resolved.
Show resolved Hide resolved
GpuExpressionsUtils.columnarEvalToColumn(expr, batch)
} else {
val elseValues = withResource(
filterBatch(tbl, elsePredNoNulls.getBase, colTypes)) {
elseBatch => GpuExpressionsUtils.columnarEvalToColumn(expr, elseBatch)
}
withResource(elseValues) { _ =>
withResource(gather(elsePredNoNulls.getBase, elseValues)) { gather =>
jlowe marked this conversation as resolved.
Show resolved Hide resolved
GpuColumnVector.from(elsePredNoNulls.getBase.ifElse(
gather, currentValue.get.getBase), dataType)
}
}
}

case None =>
// if there is no ELSE condition then we return NULL for any rows not matched by
// previous branches
withResource(GpuScalar.from(null, dataType)) { nullScalar =>
if (isAllTrue(elsePredNoNulls) && elsePredNoNulls.getRowCount <= Int.MaxValue) {
jlowe marked this conversation as resolved.
Show resolved Hide resolved
GpuColumnVector.from(nullScalar, elsePredNoNulls.getRowCount.toInt, dataType)
} else {
GpuColumnVector.from(
elsePredNoNulls.getBase.ifElse(nullScalar, currentValue.get.getBase),
dataType)
}
}
}
}
}
}
} finally {
currentValue.foreach(_.safeClose())
cumulativePred.foreach(_.safeClose())
}
}

Expand Down
Loading