
Commit

arrow zerocopy for read and write in object store
Deegue committed May 25, 2023
1 parent 8ff162a commit 64652f4
Showing 5 changed files with 77 additions and 54 deletions.
RayExecutorUtils.java
@@ -21,11 +21,12 @@
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.api.call.ActorCreator;
import io.ray.api.placementgroup.PlacementGroup;
import io.ray.runtime.object.ObjectRefImpl;

import java.util.Map;
import java.util.List;

import io.ray.api.placementgroup.PlacementGroup;
import io.ray.runtime.object.ObjectRefImpl;
import org.apache.spark.executor.RayDPExecutor;

public class RayExecutorUtils {
RayDatasetRDD.scala
@@ -22,6 +22,7 @@ import java.util.List;
import scala.collection.JavaConverters._

import io.ray.runtime.generated.Common.Address
import org.apache.arrow.vector.VectorSchemaRoot

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.api.java.JavaSparkContext
@@ -37,15 +38,15 @@ class RayDatasetRDD(
jsc: JavaSparkContext,
@transient val objectIds: List[Array[Byte]],
locations: List[Array[Byte]])
extends RDD[Array[Byte]](jsc.sc, Nil) {
extends RDD[VectorSchemaRoot](jsc.sc, Nil) {

override def getPartitions: Array[Partition] = {
objectIds.asScala.zipWithIndex.map { case (k, i) =>
new RayDatasetRDDPartition(k, i).asInstanceOf[Partition]
}.toArray
}

override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
override def compute(split: Partition, context: TaskContext): Iterator[VectorSchemaRoot] = {
val ref = split.asInstanceOf[RayDatasetRDDPartition].ref
ObjectStoreReader.getBatchesFromStream(ref)
}
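
With RayDatasetRDD now typed as RDD[VectorSchemaRoot], each partition hands downstream code a ready-to-use Arrow batch rather than a serialized IPC stream. The following standalone sketch is not part of this commit; the `IntVector` column and its values are invented for illustration. It only shows what such a batch is: a set of Arrow column vectors plus a row count.

```scala
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{IntVector, VectorSchemaRoot}

object VectorSchemaRootSketch {
  def main(args: Array[String]): Unit = {
    val allocator = new RootAllocator(Long.MaxValue)

    // One Arrow column ("id") with three rows: 0, 1, 2.
    val vector = new IntVector("id", allocator)
    vector.allocateNew(3)
    (0 until 3).foreach(i => vector.setSafe(i, i))
    vector.setValueCount(3)

    // A VectorSchemaRoot bundles column vectors with a row count; this is the
    // object the RDD partitions now carry instead of serialized bytes.
    val root = VectorSchemaRoot.of(vector)
    println(s"rows=${root.getRowCount}, columns=${root.getFieldVectors.size()}")

    root.close()       // releases the underlying Arrow buffers
    allocator.close()
  }
}
```
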
ObjectStoreReader.scala
@@ -18,18 +18,26 @@
package org.apache.spark.sql.raydp

import java.io.ByteArrayInputStream
import java.nio.ByteBuffer
import java.nio.channels.{Channels, ReadableByteChannel}
import java.util.List

import scala.collection.JavaConverters._

import com.intel.raydp.shims.SparkShimLoader
import org.apache.arrow.vector.VectorSchemaRoot

import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.TaskContext
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.raydp.RayDPUtils
import org.apache.spark.rdd.{RayDatasetRDD, RayObjectRefRDD}
import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}

object ObjectStoreReader {
def createRayObjectRefDF(
@@ -40,17 +48,56 @@ object ObjectStoreReader {
spark.createDataFrame(rdd, schema)
}

def fromRootIterator(
arrowRootIter: Iterator[VectorSchemaRoot],
schema: StructType,
timeZoneId: String): Iterator[InternalRow] = {
val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)

new Iterator[InternalRow] {
private var rowIter = if (arrowRootIter.hasNext) nextBatch() else Iterator.empty

override def hasNext: Boolean = rowIter.hasNext || {
if (arrowRootIter.hasNext) {
rowIter = nextBatch()
true
} else {
false
}
}

override def next(): InternalRow = rowIter.next()

private def nextBatch(): Iterator[InternalRow] = {
val root = arrowRootIter.next()
val columns = root.getFieldVectors.asScala.map { vector =>
new ArrowColumnVector(vector).asInstanceOf[ColumnVector]
}.toArray

val batch = new ColumnarBatch(columns)
batch.setNumRows(root.getRowCount)
root.close()
batch.rowIterator().asScala
}
}
}

def RayDatasetToDataFrame(
sparkSession: SparkSession,
rdd: RayDatasetRDD,
schema: String): DataFrame = {
SparkShimLoader.getSparkShims.toDataFrame(JavaRDD.fromRDD(rdd), schema, sparkSession)
schemaString: String): DataFrame = {
val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
val sqlContext = new SQLContext(sparkSession)
val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
val resultRDD = JavaRDD.fromRDD(rdd).rdd.mapPartitions { it =>
fromRootIterator(it, schema, timeZoneId)
}
sqlContext.internalCreateDataFrame(resultRDD.setName("arrow"), schema)
}

def getBatchesFromStream(
ref: Array[Byte]): Iterator[Array[Byte]] = {
val objectRef = RayDPUtils.readBinary(ref, classOf[Array[Byte]])
ArrowConverters.getBatchesFromStream(
Channels.newChannel(new ByteArrayInputStream(objectRef.get)))
ref: Array[Byte]): Iterator[VectorSchemaRoot] = {
val objectRef = RayDPUtils.readBinary(ref, classOf[VectorSchemaRoot])
Iterator[VectorSchemaRoot](objectRef.get)
}
}
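
Putting the reader pieces together, here is a hedged caller-side sketch of the new read path. The wrapper object and method names are hypothetical; `spark`, `jsc` and the byte arrays identifying the Ray objects are assumed to come from RayDP's existing glue code, and the schema string is the Spark schema JSON that `DataType.fromJson` expects.

```scala
import java.util.{List => JList}

import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.rdd.RayDatasetRDD
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.raydp.ObjectStoreReader

object ReadPathSketch {
  def rayObjectsToDataFrame(spark: SparkSession,
                            jsc: JavaSparkContext,
                            objectIds: JList[Array[Byte]],
                            locations: JList[Array[Byte]],
                            schemaJson: String): DataFrame = {
    // One RDD partition per Ray object ref; compute() now yields the shared
    // VectorSchemaRoot instead of an Arrow IPC byte stream.
    val rdd = new RayDatasetRDD(jsc, objectIds, locations)
    // RayDatasetToDataFrame parses the schema JSON, wraps each root in a
    // ColumnarBatch (fromRootIterator) and builds the DataFrame from the rows.
    ObjectStoreReader.RayDatasetToDataFrame(spark, rdd, schemaJson)
  }
}
```
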
ObjectStoreWriter.scala
@@ -17,7 +17,6 @@

package org.apache.spark.sql.raydp


import java.io.ByteArrayOutputStream
import java.util.{List, UUID}
import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}
@@ -61,17 +60,16 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
val uuid: UUID = ObjectStoreWriter.dfToId.getOrElseUpdate(df, UUID.randomUUID())

def writeToRay(
data: Array[Byte],
root: VectorSchemaRoot,
numRecords: Int,
queue: ObjectRefHolder.Queue,
ownerName: String): RecordBatch = {

var objectRef: ObjectRef[Array[Byte]] = null
var objectRef: ObjectRef[VectorSchemaRoot] = null
if (ownerName == "") {
objectRef = Ray.put(data)
objectRef = Ray.put(root)
} else {
var dataOwner: PyActorHandle = Ray.getActor(ownerName).get()
objectRef = Ray.put(data, dataOwner)
objectRef = Ray.put(root, dataOwner)
}

// add the objectRef to the objectRefHolder to avoid reference GC
@@ -111,21 +109,15 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
val root = VectorSchemaRoot.create(arrowSchema, allocator)
val results = new ArrayBuffer[RecordBatch]()

val byteOut = new ByteArrayOutputStream()
val arrowWriter = ArrowWriter.create(root)
var numRecords: Int = 0

Utils.tryWithSafeFinally {
while (batchIter.hasNext) {
// reset the state
numRecords = 0
byteOut.reset()
arrowWriter.reset()

// write out the schema meta data
val writer = new ArrowStreamWriter(root, null, byteOut)
writer.start()

// get the next record batch
val nextBatch = batchIter.next()

@@ -136,19 +128,11 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {

// set the write record count
arrowWriter.finish()
// write out the record batch to the underlying out
writer.writeBatch()

// get the wrote ByteArray and save to Ray ObjectStore
val byteArray = byteOut.toByteArray
results += writeToRay(byteArray, numRecords, queue, ownerName)
// end writes footer to the output stream and doesn't clean any resources.
// It could throw exception if the output stream is closed, so it should be
// in the try block.
writer.end()

// write the schema root directly and save to Ray ObjectStore
results += writeToRay(root, numRecords, queue, ownerName)
}
arrowWriter.reset()
byteOut.close()
} {
// If we close root and allocator in TaskCompletionListener, there could be a race
// condition where the writer thread keeps writing to the VectorSchemaRoot while
Expand All @@ -173,7 +157,7 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
/**
* For test.
*/
def getRandomRef(): List[Array[Byte]] = {
def getRandomRef(): List[VectorSchemaRoot] = {

df.queryExecution.toRdd.mapPartitions { _ =>
Iterator(ObjectRefHolder.getRandom(uuid))
@@ -233,7 +217,7 @@ object ObjectStoreWriter {
var executorIds = df.sqlContext.sparkContext.getExecutorIds.toArray
val numExecutors = executorIds.length
val appMasterHandle = Ray.getActor(RayAppMaster.ACTOR_NAME)
.get.asInstanceOf[ActorHandle[RayAppMaster]]
.get.asInstanceOf[ActorHandle[RayAppMaster]]
val restartedExecutors = RayAppMasterUtils.getRestartedExecutors(appMasterHandle)
// Check if there is any restarted executors
if (!restartedExecutors.isEmpty) {
@@ -251,8 +235,8 @@
val refs = new Array[ObjectRef[Array[Byte]]](numPartitions)
val handles = executorIds.map {id =>
Ray.getActor("raydp-executor-" + id)
.get
.asInstanceOf[ActorHandle[RayDPExecutor]]
.get
.asInstanceOf[ActorHandle[RayDPExecutor]]
}
val handlesMap = (executorIds zip handles).toMap
val locations = RayExecutorUtils.getBlockLocations(
@@ -261,18 +245,15 @@
// TODO use getPreferredLocs, but we don't have a host ip to actor table now
refs(i) = RayExecutorUtils.getRDDPartition(
handlesMap(locations(i)), rdd.id, i, schema, driverAgentUrl)
queue.add(refs(i))
}
for (i <- 0 until numPartitions) {
queue.add(RayDPUtils.readBinary(refs(i).get(), classOf[VectorSchemaRoot]))
results(i) = RayDPUtils.convert(refs(i)).getId.getBytes
}
results
}

}

object ObjectRefHolder {
type Queue = ConcurrentLinkedQueue[ObjectRef[Array[Byte]]]
type Queue = ConcurrentLinkedQueue[ObjectRef[VectorSchemaRoot]]
private val dfToQueue = new ConcurrentHashMap[UUID, Queue]()

def getQueue(df: UUID): Queue = {
Expand All @@ -297,7 +278,7 @@ object ObjectRefHolder {
queue.size()
}

def getRandom(df: UUID): Array[Byte] = {
def getRandom(df: UUID): VectorSchemaRoot = {
val queue = checkQueueExists(df)
val ref = RayDPUtils.convert(queue.peek())
ref.get()
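
On the write side, the ArrowStreamWriter/ByteArrayOutputStream copy is replaced by a direct Ray.put of the VectorSchemaRoot. Below is a minimal sketch of that call pattern, assuming an initialized Ray session and an already populated root; the helper object and method names are hypothetical.

```scala
import io.ray.api.{ObjectRef, Ray}
import org.apache.arrow.vector.VectorSchemaRoot

object WritePathSketch {
  // Mirrors the new writeToRay: hand the Arrow root itself to the object store
  // instead of streaming it into a byte array first.
  def putBatch(root: VectorSchemaRoot, ownerName: String): ObjectRef[VectorSchemaRoot] = {
    if (ownerName.isEmpty) {
      Ray.put(root)
    } else {
      // Assign ownership to a named actor, as the commit does when ownerName
      // is non-empty.
      val owner = Ray.getActor(ownerName).get()
      Ray.put(root, owner)
    }
  }
}
```

Whether this is truly zero-copy depends on how Ray's Java serializer handles the Arrow buffers; the sketch only shows the API shape the commit relies on.
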
11 changes: 2 additions & 9 deletions python/raydp/spark/dataset.py
@@ -237,7 +237,7 @@ def _convert_blocks_to_dataframe(blocks):
return df

def _convert_by_rdd(spark: sql.SparkSession,
blocks: Dataset,
blocks: List[ObjectRef],
locations: List[bytes],
schema: StructType) -> DataFrame:
object_ids = [block.binary() for block in blocks]
@@ -269,14 +269,7 @@ def ray_dataset_to_spark_dataframe(spark: sql.SparkSession,
schema = StructType()
for field in arrow_schema:
schema.add(field.name, from_arrow_type(field.type), nullable=field.nullable)
#TODO how to branch on type of block?
sample = ray.get(blocks[0])
if isinstance(sample, bytes):
return _convert_by_rdd(spark, blocks, locations, schema)
elif isinstance(sample, pa.Table):
return _convert_by_udf(spark, blocks, locations, schema)
else:
raise RuntimeError("ray.to_spark only supports arrow type blocks")
return _convert_by_rdd(spark, blocks, locations, schema)

if HAS_MLDATASET:
class RecordBatch(_SourceShard):
