Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change order of initialization so pinned pool is available for spill framework buffers #11889

Open
wants to merge 3 commits into
base: branch-25.02
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
* Copyright (c) 2020-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -284,10 +284,29 @@ object GpuDeviceManager extends Logging {

private var memoryEventHandler: DeviceMemoryEventHandler = _

private def initializeRmm(gpuId: Int, rapidsConf: Option[RapidsConf]): Unit = {
if (!Rmm.isInitialized) {
val conf = rapidsConf.getOrElse(new RapidsConf(SparkEnv.get.conf))
/**
 * Initializes the spill framework and installs the device memory event handler.
 *
 * Steps, in order:
 *  1. `SpillFramework.initialize(conf)` is called.
 *  2. `memoryEventHandler` is assigned a new `DeviceMemoryEventHandler` built from the
 *     spill framework's device store plus the OOM dump directory and max-retry settings
 *     read from `conf`.
 *  3. The handler is registered with `RmmSpark` when `conf.sparkRmmStateEnable` is set
 *     (passing `null` as the debug location when the configured location is empty);
 *     otherwise a warning is logged and the handler is registered directly with `Rmm`.
 *
 * @param conf the RAPIDS configuration the spill and OOM settings are read from
 */
private def initializeSpillAndMemoryEvents(conf: RapidsConf): Unit = {
  SpillFramework.initialize(conf)

  // The handler is built on the device store, so the spill framework must be up first.
  val deviceStore = SpillFramework.stores.deviceStore
  memoryEventHandler =
    new DeviceMemoryEventHandler(deviceStore, conf.gpuOomDumpDir, conf.gpuOomMaxRetries)

  if (!conf.sparkRmmStateEnable) {
    logWarning("SparkRMM retry has been disabled")
    Rmm.setEventHandler(memoryEventHandler)
  } else {
    // An empty configured location is passed on as null.
    val configuredLocation = conf.sparkRmmDebugLocation
    val debugLoc = if (configuredLocation.isEmpty) null else configuredLocation
    RmmSpark.setEventHandler(memoryEventHandler, debugLoc)
  }
}

private def initializeRmmGpuPool(gpuId: Int, conf: RapidsConf): Unit = {
if (!Rmm.isInitialized) {
val poolSize = conf.chunkedPackPoolSize
chunkedPackMemoryResource =
if (poolSize > 0) {
Expand Down Expand Up @@ -391,30 +410,10 @@ object GpuDeviceManager extends Logging {
}
}

SpillFramework.initialize(conf)

memoryEventHandler = new DeviceMemoryEventHandler(
SpillFramework.stores.deviceStore,
conf.gpuOomDumpDir,
conf.gpuOomMaxRetries)

if (conf.sparkRmmStateEnable) {
val debugLoc = if (conf.sparkRmmDebugLocation.isEmpty) {
null
} else {
conf.sparkRmmDebugLocation
}
RmmSpark.setEventHandler(memoryEventHandler, debugLoc)
} else {
logWarning("SparkRMM retry has been disabled")
Rmm.setEventHandler(memoryEventHandler)
}
GpuShuffleEnv.init(conf)
}
}

private def initializeOffHeapLimits(gpuId: Int, rapidsConf: Option[RapidsConf]): Unit = {
val conf = rapidsConf.getOrElse(new RapidsConf(SparkEnv.get.conf))
private def initializePinnedPoolAndOffHeapLimits(gpuId: Int, conf: RapidsConf): Unit = {
val setCuioDefaultResource = conf.pinnedPoolCuioDefault
val (pinnedSize, nonPinnedLimit) = if (conf.offHeapLimitEnabled) {
logWarning("OFF HEAP MEMORY LIMITS IS ENABLED. " +
Expand Down Expand Up @@ -508,8 +507,13 @@ object GpuDeviceManager extends Logging {
"Cannot initialize memory due to previous shutdown failing")
} else if (singletonMemoryInitialized == Uninitialized) {
val gpu = gpuId.getOrElse(findGpuAndAcquire())
initializeRmm(gpu, rapidsConf)
initializeOffHeapLimits(gpu, rapidsConf)
val conf = rapidsConf.getOrElse(new RapidsConf(SparkEnv.get.conf))
initializePinnedPoolAndOffHeapLimits(gpu, conf)
initializeRmmGpuPool(gpu, conf)
// we want to initialize this last because we want to take advantage
// of pinned memory if it is configured
initializeSpillAndMemoryEvents(conf)
GpuShuffleEnv.init(conf)
singletonMemoryInitialized = Initialized
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -131,18 +131,23 @@ object GpuShuffleEnv extends Logging {
def isRapidsShuffleAvailable(conf: RapidsConf): Boolean = {
// the driver has `mgr` defined when this is checked
val sparkEnv = SparkEnv.get
val isRapidsManager = sparkEnv.shuffleManager.isInstanceOf[RapidsShuffleManagerLike]
if (isRapidsManager) {
validateRapidsShuffleManager(sparkEnv.shuffleManager.getClass.getName)
if (sparkEnv == null) {
// we may hit this in some tests that don't need to use the RAPIDS shuffle manager.
false
} else {
val isRapidsManager = sparkEnv.shuffleManager.isInstanceOf[RapidsShuffleManagerLike]
if (isRapidsManager) {
validateRapidsShuffleManager(sparkEnv.shuffleManager.getClass.getName)
}
// executors have `env` defined when this is checked
// in tests
val isConfiguredInEnv = Option(env).exists(_.isRapidsShuffleConfigured)
(isConfiguredInEnv || isRapidsManager) &&
(conf.isMultiThreadedShuffleManagerMode ||
(conf.isGPUShuffle && !isExternalShuffleEnabled &&
!isSparkAuthenticateEnabled)) &&
conf.isSqlExecuteOnGPU
}
// executors have `env` defined when this is checked
// in tests
val isConfiguredInEnv = Option(env).exists(_.isRapidsShuffleConfigured)
(isConfiguredInEnv || isRapidsManager) &&
(conf.isMultiThreadedShuffleManagerMode ||
(conf.isGPUShuffle && !isExternalShuffleEnabled &&
!isSparkAuthenticateEnabled)) &&
conf.isSqlExecuteOnGPU
}

def useGPUShuffle(conf: RapidsConf): Boolean = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
* Copyright (c) 2023-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.unsafe.types.UTF8String
/**
* Unit tests for utility methods in [[ BatchWithPartitionDataUtils ]]
*/
class BatchWithPartitionDataSuite extends RmmSparkRetrySuiteBase with SparkQueryCompareTestSuite {
class BatchWithPartitionDataSuite extends RmmSparkRetrySuiteBase {

test("test splitting partition data into groups") {
val maxGpuColumnSizeBytes = 1000L
Expand All @@ -55,48 +55,46 @@ class BatchWithPartitionDataSuite extends RmmSparkRetrySuiteBase with SparkQuery
// This test uses single-row partition values that should throw a GpuSplitAndRetryOOM exception
// when a retry is forced.
val maxGpuColumnSizeBytes = 1000L
withGpuSparkSession(_ => {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why remove this?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this because of the issue where RMM gets initialized twice?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, exactly. It was only needed because this test wanted a SparkSession to be defined so that other code in the init path didn't throw an NPE (the test didn't need Spark at all otherwise).

val (_, partValues, _, partSchema) = getSamplePartitionData
closeOnExcept(buildBatch(getSampleValueData)) { valueBatch =>
val resultBatchIter = BatchWithPartitionDataUtils.addPartitionValuesToBatch(valueBatch,
Array(1), partValues.take(1), partSchema, maxGpuColumnSizeBytes)
RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1,
RmmSpark.OomInjectionType.GPU.ordinal, 0)
withResource(resultBatchIter) { _ =>
assertThrows[GpuSplitAndRetryOOM] {
resultBatchIter.next()
}
val (_, partValues, _, partSchema) = getSamplePartitionData
closeOnExcept(buildBatch(getSampleValueData)) { valueBatch =>
val resultBatchIter = BatchWithPartitionDataUtils.addPartitionValuesToBatch(valueBatch,
Array(1), partValues.take(1), partSchema, maxGpuColumnSizeBytes)
RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1,
RmmSpark.OomInjectionType.GPU.ordinal, 0)
withResource(resultBatchIter) { _ =>
assertThrows[GpuSplitAndRetryOOM] {
resultBatchIter.next()
}
}
})
}
}

test("test adding partition values to batch with OOM split and retry") {
// This test should split the input batch and process them when a retry is forced.
val maxGpuColumnSizeBytes = 1000L
withGpuSparkSession(_ => {
val (partCols, partValues, partRows, partSchema) = getSamplePartitionData
withResource(buildBatch(getSampleValueData)) { valueBatch =>
withResource(buildBatch(partCols)) { partBatch =>
withResource(GpuColumnVector.combineColumns(valueBatch, partBatch)) { expectedBatch =>
// we incRefCounts here because `addPartitionValuesToBatch` takes ownership of
// `valueBatch`, but we are keeping it alive since its columns are part of
// `expectedBatch`
GpuColumnVector.incRefCounts(valueBatch)
val resultBatchIter = BatchWithPartitionDataUtils.addPartitionValuesToBatch(valueBatch,
partRows, partValues, partSchema, maxGpuColumnSizeBytes)
withResource(resultBatchIter) { _ =>
RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1,
RmmSpark.OomInjectionType.GPU.ordinal, 0)
// Assert that the final count of rows matches expected batch
// We also need to close each batch coming from `resultBatchIter`.
val rowCounts = resultBatchIter.map(withResource(_){_.numRows()}).sum
assert(rowCounts == expectedBatch.numRows())
}
val (partCols, partValues, partRows, partSchema) = getSamplePartitionData
withResource(buildBatch(getSampleValueData)) { valueBatch =>
withResource(buildBatch(partCols)) { partBatch =>
withResource(GpuColumnVector.combineColumns(valueBatch, partBatch)) { expectedBatch =>
// we incRefCounts here because `addPartitionValuesToBatch` takes ownership of
// `valueBatch`, but we are keeping it alive since its columns are part of
// `expectedBatch`
GpuColumnVector.incRefCounts(valueBatch)
val resultBatchIter = BatchWithPartitionDataUtils.addPartitionValuesToBatch(valueBatch,
partRows, partValues, partSchema, maxGpuColumnSizeBytes)
withResource(resultBatchIter) { _ =>
RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1,
RmmSpark.OomInjectionType.GPU.ordinal, 0)
// Assert that the final count of rows matches expected batch
// We also need to close each batch coming from `resultBatchIter`.
val rowCounts = resultBatchIter.map(withResource(_) {
_.numRows()
}).sum
assert(rowCounts == expectedBatch.numRows())
}
}
}
})
}
}

private def getSamplePartitionData: (Array[Array[String]], Array[InternalRow], Array[Long],
Expand Down Expand Up @@ -140,4 +138,4 @@ class BatchWithPartitionDataSuite extends RmmSparkRetrySuiteBase with SparkQuery
GpuColumnVector.from(ColumnVector.fromStrings(v: _*), StringType))
new ColumnarBatch(colVectors.toArray, numRows)
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
* Copyright (c) 2023-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,7 +20,6 @@ import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.jni.RmmSpark

import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, ExprId, SortOrder, SpecificInternalRow}
import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
import org.apache.spark.sql.vectorized.ColumnarBatch
Expand All @@ -36,49 +35,45 @@ class ShufflePartitionerRetrySuite extends RmmSparkRetrySuiteBase {
}

private def testRoundRobinPartitioner(partNum: Int) = {
TestUtils.withGpuSparkSession(new SparkConf()) { _ =>
val rrp = GpuRoundRobinPartitioning(partNum)
// batch will be closed within columnarEvalAny
val batch = buildBatch
RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1,
RmmSpark.OomInjectionType.GPU.ordinal, 0)
var ret: Array[(ColumnarBatch, Int)] = null
try {
ret = rrp.columnarEvalAny(batch).asInstanceOf[Array[(ColumnarBatch, Int)]]
assert(partNum === ret.size)
} finally {
if (ret != null) {
ret.map(_._1).safeClose()
}
val rrp = GpuRoundRobinPartitioning(partNum)
// batch will be closed within columnarEvalAny
val batch = buildBatch
RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1,
RmmSpark.OomInjectionType.GPU.ordinal, 0)
var ret: Array[(ColumnarBatch, Int)] = null
try {
ret = rrp.columnarEvalAny(batch).asInstanceOf[Array[(ColumnarBatch, Int)]]
assert(partNum === ret.size)
} finally {
if (ret != null) {
ret.map(_._1).safeClose()
}
}
}

test("GPU range partition with retry") {
TestUtils.withGpuSparkSession(new SparkConf()) { _ =>
// Initialize range bounds
val fieldTypes: Array[DataType] = Array(IntegerType)
val bounds = new SpecificInternalRow(fieldTypes)
bounds.setInt(0, 3)
// Initialize GPU sorter
val ref = GpuBoundReference(0, IntegerType, nullable = true)(ExprId(0), "a")
val sortOrder = SortOrder(ref, Ascending)
val attrs = AttributeReference(ref.name, ref.dataType, ref.nullable)()
val gpuSorter = new GpuSorter(Seq(sortOrder), Array(attrs))
// Initialize range bounds
val fieldTypes: Array[DataType] = Array(IntegerType)
val bounds = new SpecificInternalRow(fieldTypes)
bounds.setInt(0, 3)
// Initialize GPU sorter
val ref = GpuBoundReference(0, IntegerType, nullable = true)(ExprId(0), "a")
val sortOrder = SortOrder(ref, Ascending)
val attrs = AttributeReference(ref.name, ref.dataType, ref.nullable)()
val gpuSorter = new GpuSorter(Seq(sortOrder), Array(attrs))

val rp = GpuRangePartitioner(Array.apply(bounds), gpuSorter)
// batch will be closed within columnarEvalAny
val batch = buildBatch
RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1,
RmmSpark.OomInjectionType.GPU.ordinal, 0)
var ret: Array[(ColumnarBatch, Int)] = null
try {
ret = rp.columnarEvalAny(batch).asInstanceOf[Array[(ColumnarBatch, Int)]]
assert(ret.length === 2)
} finally {
if (ret != null) {
ret.map(_._1).safeClose()
}
val rp = GpuRangePartitioner(Array.apply(bounds), gpuSorter)
// batch will be closed within columnarEvalAny
val batch = buildBatch
RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1,
RmmSpark.OomInjectionType.GPU.ordinal, 0)
var ret: Array[(ColumnarBatch, Int)] = null
try {
ret = rp.columnarEvalAny(batch).asInstanceOf[Array[(ColumnarBatch, Int)]]
assert(ret.length === 2)
} finally {
if (ret != null) {
ret.map(_._1).safeClose()
}
}
}
Expand Down
Loading
Loading