Commit

Integrate with kudo
liurenjie1024 committed Nov 15, 2024
1 parent a8010cc commit 54b8dad
Showing 11 changed files with 442 additions and 30 deletions.
@@ -517,6 +517,14 @@ public static StructType structFromAttributes(List<Attribute> format) {
return new StructType(fields);
}

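/**
* Extracts the data types from a list of attributes, preserving their order.
*
* @param schema the attributes to read the data types from
* @return the data types, one per attribute
*/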
public static DataType[] dataTypesFromAttributes(List<Attribute> schema) {
DataType[] types = new DataType[schema.size()];
for (int i = 0; i < schema.size(); i++) {
types[i] = schema.get(i).dataType();
}
return types;
}

/**
* Convert a Spark schema into a cudf schema
* @param input the Spark schema to convert
@@ -528,6 +536,56 @@ public static Schema from(StructType input) {
return builder.build();
}

/**
* Converts an array of Spark data types to a cudf schema.
*
* <br/>
*
* This method correctly handles nested types, but generates synthetic field names
* (e.g. {@code _col_0_1}) rather than preserving the original column names.
*
* @param dataTypes the data types to convert
* @return the cudf schema
*/
public static Schema from(DataType[] dataTypes) {
Schema.Builder builder = Schema.builder();
visit(dataTypes, builder, 0);
return builder.build();
}

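/**
* Recursively appends the given data types to the schema builder, generating a synthetic
* name of the form "_col_" + level + "_" + idx for each column.
*/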
private static void visit(DataType[] dataTypes, Schema.Builder builder, int level) {
for (int idx = 0; idx < dataTypes.length; idx ++) {
DataType dt = dataTypes[idx];
String name = "_col_" + level + "_" + idx;
if (dt instanceof MapType) {
// MapType is a list of structs in cudf, so it needs special handling.
Schema.Builder listBuilder = builder.addColumn(DType.LIST, name);
Schema.Builder structBuilder = listBuilder.addColumn(DType.STRUCT, name + "_map");
MapType mt = (MapType) dt;
DataType[] structChildren = {mt.keyType(), mt.valueType()};
visit(structChildren, structBuilder, level + 1);
} else if (dt instanceof BinaryType) {
Schema.Builder listBuilder = builder.addColumn(DType.LIST, name);
listBuilder.addColumn(DType.UINT8, name + "_bytes");
} else {
Schema.Builder childBuilder = builder.addColumn(GpuColumnVector.getRapidsType(dt), name);
if (dt instanceof ArrayType) {
// Array (aka List)
DataType[] childType = {((ArrayType) dt).elementType()};
visit(childType, childBuilder, level + 1);
} else if (dt instanceof StructType) {
// Struct
StructType st = (StructType) dt;
DataType[] childrenTypes = new DataType[st.length()];
for (int i = 0; i < childrenTypes.length; i ++) {
childrenTypes[i] = st.apply(i).dataType();
}
visit(childrenTypes, childBuilder, level + 1);
}
}
}
}
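For context, a minimal usage sketch of the new overload in Scala (illustrative only; it assumes this method lives on GpuColumnVector, which matches how the serializer below calls it via GpuColumnVector.from(dataTypes), and the column names follow the synthetic "_col_<level>_<idx>" scheme produced by visit above):

import org.apache.spark.sql.types.{ArrayType, DataType, DoubleType, IntegerType, LongType, MapType, StringType}

val dataTypes: Array[DataType] = Array(
  IntegerType,                   // -> INT32 "_col_0_0"
  MapType(StringType, LongType), // -> LIST of STRUCT<STRING, INT64> "_col_0_1"
  ArrayType(DoubleType))         // -> LIST of FLOAT64 "_col_0_2"

val cudfSchema: ai.rapids.cudf.Schema = GpuColumnVector.from(dataTypes)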


/**
* Convert a ColumnarBatch to a table. The table will increment the reference count for all of
* the columns in the batch, so you will need to close both the batch passed in and the table
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -27,10 +27,11 @@ import ai.rapids.cudf.JCudfSerialization.SerializedTableHeader
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion
import com.nvidia.spark.rapids.jni.kudo.{KudoSerializer, KudoTable}

import org.apache.spark.TaskContext
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance}
import org.apache.spark.sql.types.NullType
import org.apache.spark.sql.types.{DataType, NullType}
import org.apache.spark.sql.vectorized.ColumnarBatch

class SerializedBatchIterator(dIn: DataInputStream)
@@ -49,7 +50,7 @@ class SerializedBatchIterator(dIn: DataInputStream)
}

def tryReadNextHeader(): Option[Long] = {
if (streamClosed){
if (streamClosed) {
None
} else {
if (nextHeader.isEmpty) {
@@ -108,6 +109,7 @@ class SerializedBatchIterator(dIn: DataInputStream)
(0, ret)
}
}

/**
* Serializer for serializing `ColumnarBatch`s for use during normal shuffle.
*
@@ -124,10 +126,16 @@ class SerializedBatchIterator(dIn: DataInputStream)
*
* @note The RAPIDS shuffle does not use this code.
*/
class GpuColumnarBatchSerializer(dataSize: GpuMetric)
extends Serializer with Serializable {
override def newInstance(): SerializerInstance =
new GpuColumnarBatchSerializerInstance(dataSize)
class GpuColumnarBatchSerializer(dataSize: GpuMetric, dataTypes: Array[DataType], useKudo: Boolean)
extends Serializer with Serializable {
override def newInstance(): SerializerInstance = {
if (useKudo) {
new KudoSerializerInstance(dataSize, dataTypes)
} else {
new GpuColumnarBatchSerializerInstance(dataSize)
}
}

override def supportsRelocationOfSerializedObjects: Boolean = true
}
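A minimal construction sketch in Scala (a sketch only: it assumes NoopMetric is acceptable as the size metric and that the types match the batches being shuffled; real call sites pass the shuffle write metric and the plan's output types):

import org.apache.spark.sql.types.{DataType, IntegerType, StringType}

val types: Array[DataType] = Array(IntegerType, StringType)
// useKudo = true selects the Kudo path; false keeps the existing GpuColumnarBatchSerializerInstance.
val serializer = new GpuColumnarBatchSerializer(NoopMetric, types, useKudo = true)
val instance = serializer.newInstance()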

@@ -252,8 +260,10 @@ private class GpuColumnarBatchSerializerInstance(dataSize: GpuMetric) extends Se

// These methods are never called by shuffle code.
override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException

override def deserialize[T: ClassTag](bytes: ByteBuffer): T =
throw new UnsupportedOperationException

override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
throw new UnsupportedOperationException
}
@@ -282,7 +292,7 @@ object SerializedTableColumn {
* Build a `ColumnarBatch` consisting of a single [[SerializedTableColumn]] describing
* the specified serialized table.
*
* @param header header for the serialized table
* @param header header for the serialized table
* @param hostBuffer host buffer containing the table data
* @return columnar batch to be passed to [[GpuShuffleCoalesceExec]]
*/
@@ -299,11 +309,235 @@ object SerializedTableColumn {
val cv = batch.column(0)
cv match {
case serializedTableColumn: SerializedTableColumn
if serializedTableColumn.hostBuffer != null =>
if serializedTableColumn.hostBuffer != null =>
sum += serializedTableColumn.hostBuffer.getLength
case _ =>
}
}
sum
}
}

/**
* Serializer instance for serializing `ColumnarBatch`s for use during shuffle with
* [[KudoSerializer]]
*
* @param dataSize metric to track the size of the serialized data
* @param dataTypes data types of the columns in the batch
*/
private class KudoSerializerInstance(
val dataSize: GpuMetric,
val dataTypes: Array[DataType]) extends SerializerInstance {

private lazy val kudo = new KudoSerializer(GpuColumnVector.from(dataTypes))

override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
private[this] val dOut: DataOutputStream =
new DataOutputStream(new BufferedOutputStream(out))

override def writeValue[T: ClassTag](value: T): SerializationStream = {
val batch = value.asInstanceOf[ColumnarBatch]
val numColumns = batch.numCols()
val columns: Array[HostColumnVector] = new Array(numColumns)
withResource(new ArrayBuffer[AutoCloseable]()) { toClose =>
var startRow = 0
val numRows = batch.numRows()
if (batch.numCols() > 0) {
val firstCol = batch.column(0)
if (firstCol.isInstanceOf[SlicedGpuColumnVector]) {
// ColumnarBatch gives us no place to record the slice bounds, so each sliced column
// carries them; read the start row from the first column.
startRow = firstCol.asInstanceOf[SlicedGpuColumnVector].getStart
for (i <- 0 until numColumns) {
columns(i) = batch.column(i).asInstanceOf[SlicedGpuColumnVector].getBase
}
} else {
for (i <- 0 until numColumns) {
batch.column(i) match {
case gpu: GpuColumnVector =>
val cpu = gpu.copyToHost()
toClose += cpu
columns(i) = cpu.getBase
case cpu: RapidsHostColumnVector =>
columns(i) = cpu.getBase
}
}
}

withResource(new NvtxRange("Serialize Batch", NvtxColor.YELLOW)) { _ =>
dataSize += kudo.writeToStream(columns, dOut, startRow, numRows)
}
} else {
withResource(new NvtxRange("Serialize Row Only Batch", NvtxColor.YELLOW)) { _ =>
dataSize += KudoSerializer.writeRowCountToStream(dOut, numRows)
}
}
this
}
}

override def writeKey[T: ClassTag](key: T): SerializationStream = {
// The key is only needed on the map side when computing partition ids. It does not need to
// be shuffled.
assert(null == key || key.isInstanceOf[Int])
this
}

override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = {
// This method is never called by shuffle code.
throw new UnsupportedOperationException
}

override def writeObject[T: ClassTag](t: T): SerializationStream = {
// This method is never called by shuffle code.
throw new UnsupportedOperationException
}

override def flush(): Unit = {
dOut.flush()
}

override def close(): Unit = {
dOut.close()
}
}

override def deserializeStream(in: InputStream): DeserializationStream = {
new DeserializationStream {
private[this] val dIn: DataInputStream = new DataInputStream(new BufferedInputStream(in))

override def asKeyValueIterator: Iterator[(Int, ColumnarBatch)] = {
new KudoSerializedBatchIterator(dIn)
}

override def asIterator: Iterator[Any] = {
// This method is never called by shuffle code.
throw new UnsupportedOperationException
}

override def readKey[T]()(implicit classType: ClassTag[T]): T = {
// We skipped serialization of the key in writeKey(), so just return a dummy value since
// this is going to be discarded anyway.
null.asInstanceOf[T]
}

override def readValue[T]()(implicit classType: ClassTag[T]): T = {
// This method should never be called by shuffle code.
throw new UnsupportedOperationException
}

override def readObject[T]()(implicit classType: ClassTag[T]): T = {
// This method is never called by shuffle code.
throw new UnsupportedOperationException
}

override def close(): Unit = {
dIn.close()
}
}
}

// These methods are never called by shuffle code.
override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException

override def deserialize[T: ClassTag](bytes: ByteBuffer): T =
throw new UnsupportedOperationException

override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
throw new UnsupportedOperationException
}
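An illustrative write/read round trip, as a sketch under assumptions: KudoSerializerInstance is private to this file, so the instance is obtained through GpuColumnarBatchSerializer with useKudo = true; `batch` stands in for a real ColumnarBatch of GPU columns and needs a GPU plus cudf at runtime; NoopMetric is used only to keep the example short.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import org.apache.spark.sql.types.{DataType, IntegerType}

val kudoInstance = new GpuColumnarBatchSerializer(NoopMetric, Array[DataType](IntegerType), useKudo = true)
  .newInstance()

val out = new ByteArrayOutputStream()
val writer = kudoInstance.serializeStream(out)
writer.writeKey(0)        // partition id; only used on the map side, never written to the stream
writer.writeValue(batch)  // `batch` is a ColumnarBatch of GpuColumnVector columns (not built here)
writer.close()

// Reading back yields (0, ColumnarBatch wrapping a KudoSerializedTableColumn) pairs.
val reader = kudoInstance
  .deserializeStream(new ByteArrayInputStream(out.toByteArray))
  .asKeyValueIterator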

/**
* A special `ColumnVector` that describes a serialized table read from shuffle using
* [[KudoSerializer]].
*
* This appears in a `ColumnarBatch` to pass serialized tables to [[GpuShuffleCoalesceExec]]
* which should always appear in the query plan immediately after a shuffle.
*/
case class KudoSerializedTableColumn(kudoTable: KudoTable) extends GpuColumnVectorBase(NullType) {
override def close(): Unit = {
if (kudoTable != null) {
kudoTable.close()
}
}

override def hasNull: Boolean = throw new IllegalStateException("should not be called")

override def numNulls(): Int = throw new IllegalStateException("should not be called")
}

object KudoSerializedTableColumn {
/**
* Build a `ColumnarBatch` consisting of a single [[KudoSerializedTableColumn]] describing
* the specified serialized table.
*
* @param kudoTable Serialized kudo table.
* @return columnar batch to be passed to [[GpuShuffleCoalesceExec]]
*/
def from(kudoTable: KudoTable): ColumnarBatch = {
val column = new KudoSerializedTableColumn(kudoTable)
new ColumnarBatch(Array(column), kudoTable.getHeader.getNumRows)
}

def getMemoryUsed(batch: ColumnarBatch): Long = {
if (batch.numCols == 1) {
val cv = batch.column(0)
cv match {
case KudoSerializedTableColumn(kudoTable: KudoTable) =>
Option(kudoTable.getBuffer).map(_.getLength).getOrElse(0L)
case _ => 0L
}
} else {
0L
}
}
}
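A short consumption sketch (a hypothetical `dIn` stands for the DataInputStream of a shuffle block; the behavior noted in the comments mirrors getMemoryUsed above and the iterator below):

import java.io.DataInputStream
import com.nvidia.spark.rapids.jni.kudo.KudoTable

def readOne(dIn: DataInputStream): Unit = {
  val maybeTable = KudoTable.from(dIn)   // Optional.empty() once the stream is exhausted
  if (maybeTable.isPresent) {
    val batch = KudoSerializedTableColumn.from(maybeTable.get())
    val bytes = KudoSerializedTableColumn.getMemoryUsed(batch) // 0L for row-count-only tables
    // ... hand `batch` to GpuShuffleCoalesceExec-style coalescing, then close it ...
    batch.close()
  }
}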

class KudoSerializedBatchIterator(dIn: DataInputStream)
extends Iterator[(Int, ColumnarBatch)] {
private[this] var nextBatch: Option[ColumnarBatch] = None
private[this] var streamClosed: Boolean = false

// Don't install the callback if in a unit test
Option(TaskContext.get()).foreach { tc =>
onTaskCompletion(tc) {
nextBatch.foreach(_.close())
nextBatch = None
dIn.close()
}
}

def tryReadNext(): Unit = {
if (!streamClosed) {
withResource(new NvtxRange("Read Kudo Table", NvtxColor.YELLOW)) { _ =>
val kudoTable = KudoTable.from(dIn)
if (kudoTable.isPresent) {
nextBatch = Some(KudoSerializedTableColumn.from(kudoTable.get()))
} else {
dIn.close()
streamClosed = true
nextBatch = None
}
}
}
}

override def hasNext: Boolean = {
nextBatch match {
case Some(_) => true
case None =>
tryReadNext()
nextBatch.isDefined
}
}

override def next(): (Int, ColumnarBatch) = {
if (hasNext) {
val ret = nextBatch.get
nextBatch = None
(0, ret)
} else {
throw new NoSuchElementException("Walked off of the end...")
}
}
}
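And a sketch of draining the iterator (again with a hypothetical `dIn`); each element pairs a dummy partition id with a batch wrapping one KudoSerializedTableColumn:

import java.io.DataInputStream

def drain(dIn: DataInputStream): Unit = {
  val iter = new KudoSerializedBatchIterator(dIn)
  while (iter.hasNext) {
    val (_, serialized) = iter.next()
    try {
      // pass `serialized` to the shuffle coalescing step
    } finally {
      serialized.close()
    }
  }
}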
