Skip to content

Commit

Permalink
Improve the retry support for nondeterministic expressions (#11789)
Browse files Browse the repository at this point in the history
Contributes to #11649

This PR is trying to address some requirements described in issue #11649, but not all of them.

It introduces two new classes named "GpuExpressionRetryable" and "RetryStateTracker" to initially
set up a fundamental support to detect the context requirement for nondeterministic expressions,
and adds in the relevant unit tests.

And it also adds the integration tests for the function "rand()" being used in HashAggregate,
Generate, Projection, ArrowEvalPython and Filter. It still does not cover all the cases where
a nondeterministic expression can be used, but we are closer than before.

---------

Signed-off-by: Firestarman <firestarmanllc@gmail.com>
  • Loading branch information
firestarman authored Dec 23, 2024
1 parent 01f9fd2 commit 32aa3e1
Show file tree
Hide file tree
Showing 10 changed files with 269 additions and 26 deletions.
85 changes: 85 additions & 0 deletions integration_tests/src/main/python/rand_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect
from data_gen import *
from marks import *
from spark_session import is_before_spark_351

import pyspark.sql.functions as f


@ignore_order(local=True)
@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
def test_group_agg_with_rand():
# GPU and CPU produce the same grouping rows but in different orders after Shuffle,
# while the rand() always generates the same sequence. Then CPU and GPU will produce
# different final rows after aggregation. See as below:
# GPU output:
# +---+-------------------+
# | a| random|
# +---+-------------------+
# | 3| 0.619189370225301|
# | 5| 0.5096018842446481|
# | 2| 0.8325259388871524|
# | 4|0.26322809041172357|
# | 1| 0.6702867696264135|
# +---+-------------------+
# CPU output:
# +---+-------------------+
# | a| random|
# +---+-------------------+
# | 1| 0.619189370225301|
# | 2| 0.5096018842446481|
# | 3| 0.8325259388871524|
# | 4|0.26322809041172357|
# | 5| 0.6702867696264135|
# +---+-------------------+
# To make the output comparable, here builds a generator to generate only one group.
const_int_gen = IntegerGen(nullable=False, min_val=1, max_val=1, special_cases=[])

def test(spark):
return unary_op_df(spark, const_int_gen, num_slices=1).groupby('a').agg(f.rand(42))
assert_gpu_and_cpu_are_equal_collect(test)


@ignore_order(local=True)
def test_project_with_rand():
# To make the output comparable, here build a generator to generate only one value.
# Not sure if Project could have the same order issue as groupBy, but still just in case.
const_int_gen = IntegerGen(nullable=False, min_val=1, max_val=1, special_cases=[])
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, const_int_gen, num_slices=1).select('a', f.rand(42))
)


@ignore_order(local=True)
def test_filter_with_rand():
const_int_gen = IntegerGen(nullable=False, min_val=1, max_val=1, special_cases=[])
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, const_int_gen, num_slices=1).filter(f.rand(42) > 0.5)
)

# See https://github.com/apache/spark/commit/9c0b803ba124a6e70762aec1e5559b0d66529f4d
@ignore_order(local=True)
@pytest.mark.skipif(is_before_spark_351(),
reason='Generate supports nondeterministic inputs from Spark 3.5.1')
def test_generate_with_rand():
const_int_gen = IntegerGen(nullable=False, min_val=1, max_val=1, special_cases=[])
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, const_int_gen, num_slices=1).select(
f.explode(f.array(f.rand(42))))
)
1 change: 1 addition & 0 deletions integration_tests/src/main/python/spark_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def _from_scala_map(scala_map):
# Many of these are redundant with default settings for the configs but are set here explicitly
# to ensure any cluster settings do not interfere with tests that assume the defaults.
_default_conf = {
'spark.rapids.sql.test.retryContextCheck.enabled': 'true',
'spark.rapids.sql.castDecimalToFloat.enabled': 'false',
'spark.rapids.sql.castFloatToDecimal.enabled': 'false',
'spark.rapids.sql.castFloatToIntegralTypes.enabled': 'false',
Expand Down
18 changes: 17 additions & 1 deletion integration_tests/src/main/python/udf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pyspark import BarrierTaskContext, TaskContext

from conftest import is_at_least_precommit_run, is_databricks_runtime
from spark_session import is_before_spark_330, is_before_spark_350, is_spark_341
from spark_session import is_before_spark_330, is_before_spark_331, is_before_spark_350, is_spark_341

from pyspark.sql.pandas.utils import require_minimum_pyarrow_version, require_minimum_pandas_version

Expand Down Expand Up @@ -474,3 +474,19 @@ def add_one(a):
lambda spark: unary_op_df(spark, int_gen, num_slices=4, length=52345)
.select(my_udf(f.lit(0))),
conf=arrow_udf_conf)


# Python UDFs support nondeterministic expressions from Spark 3.3.1.
# See https://github.com/apache/spark/commit/1a01a492c051bb861c480f224a3c310e133e4d01
@ignore_order(local=True)
@pytest.mark.skipif(is_before_spark_331(),
reason='Nondeterministic expressions are supported from Spark 3.3.1')
def test_pandas_math_udf_with_rand():
def add(rand_value):
return rand_value
my_udf = f.pandas_udf(add, returnType=IntegerType())
assert_gpu_and_cpu_are_equal_collect(
# Ensure there is only one partition to make the output comparable.
lambda spark: unary_op_df(spark, int_gen, length=10, num_slices=1).select(
my_udf(f.rand(42))),
conf=arrow_udf_conf)
Original file line number Diff line number Diff line change
Expand Up @@ -816,13 +816,13 @@ object GpuAggFinalPassIterator {
boundResultReferences)
}

private[this] def reorderFinalBatch(finalBatch: ColumnarBatch,
private[this] def reorderFinalBatch(finalBatch: SpillableColumnarBatch,
boundExpressions: BoundExpressionsModeAggregates,
metrics: GpuHashAggregateMetrics): ColumnarBatch = {
// Perform the last project to get the correct shape that Spark expects. Note this may
// add things like literals that were not part of the aggregate into the batch.
closeOnExcept(GpuProjectExec.projectAndClose(finalBatch,
boundExpressions.boundResultReferences, NoopMetric)) { ret =>
closeOnExcept(GpuProjectExec.projectAndCloseWithRetrySingleBatch(finalBatch,
boundExpressions.boundResultReferences)) { ret =>
metrics.numOutputRows += ret.numRows()
metrics.numOutputBatches += 1
ret
Expand All @@ -838,9 +838,12 @@ object GpuAggFinalPassIterator {
withResource(new NvtxWithMetrics("finalize agg", NvtxColor.DARK_GREEN, aggTime,
opTime)) { _ =>
val finalBatch = boundExpressions.boundFinalProjections.map { exprs =>
GpuProjectExec.projectAndClose(batch, exprs, NoopMetric)
GpuProjectExec.projectAndCloseWithRetrySingleBatch(
SpillableColumnarBatch(batch, SpillPriorities.ACTIVE_BATCHING_PRIORITY), exprs)
}.getOrElse(batch)
reorderFinalBatch(finalBatch, boundExpressions, metrics)
val finalSCB =
SpillableColumnarBatch(finalBatch, SpillPriorities.ACTIVE_BATCHING_PRIORITY)
reorderFinalBatch(finalSCB, boundExpressions, metrics)
}
}
}
Expand All @@ -854,12 +857,10 @@ object GpuAggFinalPassIterator {
withResource(new NvtxWithMetrics("finalize agg", NvtxColor.DARK_GREEN, aggTime,
opTime)) { _ =>
val finalBatch = boundExpressions.boundFinalProjections.map { exprs =>
GpuProjectExec.projectAndCloseWithRetrySingleBatch(sb, exprs)
}.getOrElse {
withRetryNoSplit(sb) { _ =>
sb.getColumnarBatch()
}
}
SpillableColumnarBatch(
GpuProjectExec.projectAndCloseWithRetrySingleBatch(sb, exprs),
SpillPriorities.ACTIVE_BATCHING_PRIORITY)
}.getOrElse(sb)
reorderFinalBatch(finalBatch, boundExpressions, metrics)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2432,7 +2432,8 @@ object GpuOverrides extends Logging {
(TypeSig.INT + TypeSig.LONG).withAllLit(),
(TypeSig.INT + TypeSig.LONG).withAllLit()))),
(a, conf, p, r) => new UnaryExprMeta[Rand](a, conf, p, r) {
override def convertToGpu(child: Expression): GpuExpression = GpuRand(child)
override def convertToGpu(child: Expression): GpuExpression =
GpuRand(child, this.conf.isRetryContextCheckEnabled)
}),
expr[SparkPartitionID] (
"Returns the current partition id",
Expand Down
11 changes: 11 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1643,6 +1643,15 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern")
.booleanConf
.createWithDefault(false)

val TEST_RETRY_CONTEXT_CHECK_ENABLED = conf("spark.rapids.sql.test.retryContextCheck.enabled")
.doc("Only to be used in tests. When set to true, enable the context check for " +
"GPU nondeterministic expressions but declaring to be retryable. A GPU retryable " +
"nondeterministic expression should run inside a checkpoint-restore context. And it " +
"will blow up when the context does not satisfy.")
.internal()
.booleanConf
.createWithDefault(false)

val TEST_CONF = conf("spark.rapids.sql.test.enabled")
.doc("Intended to be used by unit tests, if enabled all operations must run on the " +
"GPU or an error happens.")
Expand Down Expand Up @@ -2733,6 +2742,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {

lazy val isTestEnabled: Boolean = get(TEST_CONF)

lazy val isRetryContextCheckEnabled: Boolean = get(TEST_RETRY_CONTEXT_CHECK_ENABLED)

lazy val isFoldableNonLitAllowed: Boolean = get(FOLDABLE_NON_LIT_ALLOWED)

lazy val asyncWriteMaxInFlightHostMemoryBytes: Long =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ object RmmRapidsRetryIterator extends Logging {
var doSplit = false
var isFromGpuOom = true
while (result.isEmpty && attemptIter.hasNext) {
RetryStateTracker.setCurThreadRetrying(!firstAttempt)
if (!firstAttempt) {
// call thread block API
try {
Expand Down Expand Up @@ -685,6 +686,7 @@ object RmmRapidsRetryIterator extends Logging {
// else another exception wrapped a retry. So we are going to try again
}
}
RetryStateTracker.clearCurThreadRetrying()
if (result.isEmpty) {
// then lastException must be set, throw it.
throw lastException
Expand Down Expand Up @@ -791,3 +793,21 @@ object RmmRapidsRetryIterator extends Logging {
case class AutoCloseableTargetSize(targetSize: Long, minSize: Long) extends AutoCloseable {
override def close(): Unit = ()
}

/**
* This leverages a ThreadLocal of boolean to track if a task thread is currently
* executing a retry. And the boolean state will be used by all the
* `GpuExpressionRetryable`s to determine if the context is safe to retry the evaluation.
*/
object RetryStateTracker {
private val localIsRetrying = new ThreadLocal[java.lang.Boolean]()

def isCurThreadRetrying: Boolean = {
val ret = localIsRetrying.get()
ret != null && ret
}

def setCurThreadRetrying(retrying: Boolean): Unit = localIsRetrying.set(retrying)

def clearCurThreadRetrying(): Unit = localIsRetrying.remove()
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.catalyst.expressions

import ai.rapids.cudf.{DType, HostColumnVector, NvtxColor, NvtxRange}
import com.nvidia.spark.Retryable
import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuLiteral}
import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuLiteral, RetryStateTracker}
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.shims.ShimUnaryExpression

Expand All @@ -30,13 +30,51 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.Utils
import org.apache.spark.util.random.rapids.RapidsXORShiftRandom

/**
* An expression expected to be evaluated inside a retry with checkpoint-restore context.
* It will throw an exception if it is retried without being checkpointed.
* All the nondeterministic GPU expressions that support Retryable should extend from
* this trait.
*/
trait GpuExpressionRetryable extends GpuExpression with Retryable {
private var checked = false

def doColumnarEval(batch: ColumnarBatch): GpuColumnVector
def doCheckpoint(): Unit
def doRestore(): Unit

def doContextCheck(): Boolean // For tests

override final def columnarEval(batch: ColumnarBatch): GpuColumnVector = {
if (doContextCheck && !checked) { // This is for tests
throw new IllegalStateException(
"The Retryable was called outside of a checkpoint-restore context")
}
if (!checked && RetryStateTracker.isCurThreadRetrying) {
// It is retrying the evaluation without checkpointing, which is not allowed.
throw new IllegalStateException(
"The Retryable should be retried only inside a checkpoint-restore context")
}
doColumnarEval(batch)
}

override final def checkpoint(): Unit = {
checked = true
doCheckpoint()
}

override final def restore(): Unit = doRestore()
}

/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
case class GpuRand(child: Expression) extends ShimUnaryExpression with GpuExpression
with ExpectsInputTypes with ExpressionWithRandomSeed with Retryable {
case class GpuRand(child: Expression, doContextCheck: Boolean) extends ShimUnaryExpression
with ExpectsInputTypes with ExpressionWithRandomSeed with GpuExpressionRetryable {

def this() = this(GpuLiteral(Utils.random.nextLong(), LongType))
def this(doContextCheck: Boolean) = this(GpuLiteral(Utils.random.nextLong(), LongType),
doContextCheck)

override def withNewSeed(seed: Long): GpuRand = GpuRand(GpuLiteral(seed, LongType))
override def withNewSeed(seed: Long): GpuRand = GpuRand(GpuLiteral(seed, LongType),
doContextCheck)

def seedExpression: Expression = child

Expand Down Expand Up @@ -76,7 +114,7 @@ case class GpuRand(child: Expression) extends ShimUnaryExpression with GpuExpres
}
}

override def columnarEval(batch: ColumnarBatch): GpuColumnVector = {
override def doColumnarEval(batch: ColumnarBatch): GpuColumnVector = {
if (curXORShiftRandomSeed.isEmpty) {
// checkpoint not called, need to init the random generator here
initRandom()
Expand All @@ -93,14 +131,14 @@ case class GpuRand(child: Expression) extends ShimUnaryExpression with GpuExpres
}
}

override def checkpoint(): Unit = {
override def doCheckpoint(): Unit = {
// In a task, checkpoint is called before columnarEval, so need to try to
// init the random generator here.
initRandom()
curXORShiftRandomSeed = Some(rng.currentSeed)
}

override def restore(): Unit = {
override def doRestore(): Unit = {
assert(wasInitialized && curXORShiftRandomSeed.isDefined)
rng.setHashedSeed(curXORShiftRandomSeed.get)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,9 @@ case class GpuArrowEvalPythonExec(
new RebatchingRoundoffIterator(iter, inputSchema, targetBatchSize, numInputRows,
numInputBatches))
val pyInputIterator = batchProducer.asIterator.map { batch =>
withResource(batch)(GpuProjectExec.project(_, boundReferences))
GpuProjectExec.projectAndCloseWithRetrySingleBatch(
SpillableColumnarBatch(batch, SpillPriorities.ACTIVE_BATCHING_PRIORITY),
boundReferences)
}

if (isPythonOnGpuEnabled) {
Expand Down
Loading

0 comments on commit 32aa3e1

Please sign in to comment.